From c44fd0c867939ecdd5089ae0fcc5879896c18151 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 13 Feb 2023 16:53:26 +0800 Subject: [PATCH 01/30] [workflow] added trigger to build doc upon release (#2678) --- .github/workflows/doc_build_after_merge.yml | 28 +++++++++++++++++++ ...heck_doc_on_pr.yml => doc_check_on_pr.yml} | 0 2 files changed, 28 insertions(+) create mode 100644 .github/workflows/doc_build_after_merge.yml rename .github/workflows/{check_doc_on_pr.yml => doc_check_on_pr.yml} (100%) diff --git a/.github/workflows/doc_build_after_merge.yml b/.github/workflows/doc_build_after_merge.yml new file mode 100644 index 000000000..dae3b70e1 --- /dev/null +++ b/.github/workflows/doc_build_after_merge.yml @@ -0,0 +1,28 @@ +name: Build Documentation upon Release + +on: + workflow_dispatch: + pull_request: + paths: + - 'version.txt' + types: + - closed + +jobs: + build-doc: + name: Trigger Documentation Build Workflow + if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI' + runs-on: ubuntu-latest + steps: + - name: trigger workflow in ColossalAI-Documentation + run: | + gh + curl \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${GH_TOKEN}"\ + -H "X-GitHub-Api-Version: 2022-11-28" \ + https://api.github.com/repos/hpcaitech/ColossalAI-Documentation/actions/workflows/deploy.yml/dispatches \ + -d '{"ref":"main"}' + env: + GH_TOKEN: ${{secrets.DOC_REPO_TOKEN}} diff --git a/.github/workflows/check_doc_on_pr.yml b/.github/workflows/doc_check_on_pr.yml similarity index 100% rename from .github/workflows/check_doc_on_pr.yml rename to .github/workflows/doc_check_on_pr.yml From 5cd8cae0c9db7ae4bfb98e42ec9c7ab2e6ebd17f Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 13 Feb 2023 17:04:49 +0800 Subject: [PATCH 02/30] [workflow] fixed communtity report ranking (#2680) --- .../scripts/generate_leaderboard_and_send_to_lark.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py index 36cdd9518..16b8957c1 100644 --- a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py +++ b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py @@ -292,7 +292,13 @@ def generate_user_engagement_leaderboard_image(github_token: str, output_path: s y = [] if len(total_engagement_count) > 0: + ranking = [] for name, count in total_engagement_count.items(): + ranking.append((name, count)) + + ranking.sort(key=lambda x: x[1], reverse=True) + + for name, count in ranking: x.append(count) y.append(name) From f0aa191f51704806e65ac849da137069bb35a6d5 Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 13 Feb 2023 17:53:15 +0800 Subject: [PATCH 03/30] [gemini] fix colo_init_context (#2683) --- colossalai/utils/model/colo_init_context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py index ab354ea70..87ae413a2 100644 --- a/colossalai/utils/model/colo_init_context.py +++ b/colossalai/utils/model/colo_init_context.py @@ -32,7 +32,7 @@ def _convert_to_coloparam(param: torch.nn.Parameter, default_pg: Optional[ProcessGroup] = None, default_dist_spec: Optional[Any] = None) -> ColoParameter: - if isinstance(param, ColoParameter): + if type(param) is ColoParameter: return param # detaching tensor is necessary for optimizers. requires_grad = param.requires_grad @@ -102,7 +102,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): """ name_list = [] for name, param in _named_params_with_replica(module): - if isinstance(param, ColoTensor): + if type(param) is ColoParameter: continue split = name.rfind('.') From df4f020ee32c5b857d2a9806d5ec40d0b2064021 Mon Sep 17 00:00:00 2001 From: HELSON Date: Mon, 13 Feb 2023 18:00:16 +0800 Subject: [PATCH 04/30] [zero1&2] only append parameters with gradients (#2681) --- colossalai/zero/sharded_optim/low_level_optim.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/sharded_optim/low_level_optim.py index d174fc6ac..89f5f9fad 100644 --- a/colossalai/zero/sharded_optim/low_level_optim.py +++ b/colossalai/zero/sharded_optim/low_level_optim.py @@ -131,7 +131,10 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # partition these param groups for data parallel training # and add buffers to parameter store for future access for group_id, param_group in enumerate(self.optim.param_groups): - group_params = param_group['params'] + group_params = list() + for param in param_group['params']: + if param.requires_grad: + group_params.append(param) # add the fp16 params to fp16_param_groups for bookkeeping self._fp16_param_groups[group_id] = group_params From 88416019e7d55dd4c28e29b9cdded60f71153964 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 13 Feb 2023 18:10:54 +0800 Subject: [PATCH 05/30] Automated submodule synchronization (#2648) Co-authored-by: github-actions --- examples/tutorial/fastfold/FastFold | 2 +- inference | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/tutorial/fastfold/FastFold b/examples/tutorial/fastfold/FastFold index 95150c384..0188361b6 160000 --- a/examples/tutorial/fastfold/FastFold +++ b/examples/tutorial/fastfold/FastFold @@ -1 +1 @@ -Subproject commit 95150c384b9b6e776cad38dd91494e74115dc4ac +Subproject commit 0188361b6e2b46bca61d37af5674eacf7ca9947f diff --git a/inference b/inference index 5d250f4af..cde4c8f4e 160000 --- a/inference +++ b/inference @@ -1 +1 @@ -Subproject commit 5d250f4af6283f65a701636628ffeef10447e650 +Subproject commit cde4c8f4e7269decb82b1b225ada278694e10f6a From 46f20bac4109c29f7a346fa6f62ee8fb66799dc5 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Mon, 13 Feb 2023 23:05:29 +0800 Subject: [PATCH 06/30] [doc] update auto parallel paper link (#2686) * [doc] update auto parallel paper link * [doc] update auto parallel paper link --- README-zh-Hans.md | 2 +- README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README-zh-Hans.md b/README-zh-Hans.md index 023b5a21c..4b0ba9c42 100644 --- a/README-zh-Hans.md +++ b/README-zh-Hans.md @@ -102,7 +102,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 - 1维, [2维](https://arxiv.org/abs/2104.05343), [2.5维](https://arxiv.org/abs/2105.14500), [3维](https://arxiv.org/abs/2105.14450) 张量并行 - [序列并行](https://arxiv.org/abs/2105.13120) - [零冗余优化器 (ZeRO)](https://arxiv.org/abs/1910.02054) - - [自动并行](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/auto_parallel_with_gpt) + - [自动并行](https://arxiv.org/abs/2302.02599) - 异构内存管理 - [PatrickStar](https://arxiv.org/abs/2108.05818) - 使用友好 diff --git a/README.md b/README.md index 6ad736a43..703e3f3bf 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ distributed training and inference in a few lines. - 1D, [2D](https://arxiv.org/abs/2104.05343), [2.5D](https://arxiv.org/abs/2105.14500), [3D](https://arxiv.org/abs/2105.14450) Tensor Parallelism - [Sequence Parallelism](https://arxiv.org/abs/2105.13120) - [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054) - - [Auto-Parallelism](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/auto_parallel_with_gpt) + - [Auto-Parallelism](https://arxiv.org/abs/2302.02599) - Heterogeneous Memory Management - [PatrickStar](https://arxiv.org/abs/2108.05818) From c3abdd085d3daa33d7a026b610651cc06fa3a246 Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 14 Feb 2023 19:37:14 +0800 Subject: [PATCH 07/30] [release] update version (#2691) --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index 717903969..abd410582 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.2.3 +0.2.4 From 1b34701027f4654bd5c543330b8969c0b001c68c Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 14 Feb 2023 22:17:25 +0800 Subject: [PATCH 08/30] [app] add chatgpt application (#2698) --- applications/ChatGPT/.gitignore | 146 +++++++++++++ applications/ChatGPT/LICENSE | 202 ++++++++++++++++++ applications/ChatGPT/README.md | 80 +++++++ applications/ChatGPT/benchmarks/README.md | 94 ++++++++ .../ChatGPT/benchmarks/benchmark_gpt_dummy.py | 183 ++++++++++++++++ .../ChatGPT/benchmarks/benchmark_gpt_dummy.sh | 45 ++++ .../benchmarks/benchmark_opt_lora_dummy.py | 178 +++++++++++++++ applications/ChatGPT/chatgpt/__init__.py | 0 .../ChatGPT/chatgpt/dataset/__init__.py | 3 + .../ChatGPT/chatgpt/dataset/reward_dataset.py | 52 +++++ .../chatgpt/experience_maker/__init__.py | 4 + .../ChatGPT/chatgpt/experience_maker/base.py | 77 +++++++ .../ChatGPT/chatgpt/experience_maker/naive.py | 36 ++++ applications/ChatGPT/chatgpt/nn/__init__.py | 18 ++ applications/ChatGPT/chatgpt/nn/actor.py | 62 ++++++ .../ChatGPT/chatgpt/nn/bloom_actor.py | 35 +++ .../ChatGPT/chatgpt/nn/bloom_critic.py | 37 ++++ applications/ChatGPT/chatgpt/nn/bloom_rm.py | 37 ++++ applications/ChatGPT/chatgpt/nn/critic.py | 47 ++++ applications/ChatGPT/chatgpt/nn/generation.py | 137 ++++++++++++ .../ChatGPT/chatgpt/nn/generation_utils.py | 92 ++++++++ applications/ChatGPT/chatgpt/nn/gpt_actor.py | 31 +++ applications/ChatGPT/chatgpt/nn/gpt_critic.py | 33 +++ applications/ChatGPT/chatgpt/nn/gpt_rm.py | 33 +++ applications/ChatGPT/chatgpt/nn/lora.py | 127 +++++++++++ applications/ChatGPT/chatgpt/nn/loss.py | 105 +++++++++ applications/ChatGPT/chatgpt/nn/opt_actor.py | 35 +++ applications/ChatGPT/chatgpt/nn/opt_critic.py | 37 ++++ applications/ChatGPT/chatgpt/nn/opt_rm.py | 33 +++ .../ChatGPT/chatgpt/nn/reward_model.py | 41 ++++ applications/ChatGPT/chatgpt/nn/utils.py | 92 ++++++++ .../ChatGPT/chatgpt/replay_buffer/__init__.py | 4 + .../ChatGPT/chatgpt/replay_buffer/base.py | 43 ++++ .../ChatGPT/chatgpt/replay_buffer/naive.py | 57 +++++ .../ChatGPT/chatgpt/replay_buffer/utils.py | 73 +++++++ .../ChatGPT/chatgpt/trainer/__init__.py | 5 + applications/ChatGPT/chatgpt/trainer/base.py | 162 ++++++++++++++ .../chatgpt/trainer/callbacks/__init__.py | 4 + .../ChatGPT/chatgpt/trainer/callbacks/base.py | 39 ++++ .../callbacks/performance_evaluator.py | 133 ++++++++++++ applications/ChatGPT/chatgpt/trainer/ppo.py | 104 +++++++++ applications/ChatGPT/chatgpt/trainer/rm.py | 77 +++++++ .../chatgpt/trainer/strategies/__init__.py | 6 + .../chatgpt/trainer/strategies/base.py | 45 ++++ .../chatgpt/trainer/strategies/colossalai.py | 125 +++++++++++ .../ChatGPT/chatgpt/trainer/strategies/ddp.py | 59 +++++ .../chatgpt/trainer/strategies/naive.py | 36 ++++ applications/ChatGPT/chatgpt/trainer/utils.py | 5 + applications/ChatGPT/examples/README.md | 105 +++++++++ .../ChatGPT/examples/requirements.txt | 1 + applications/ChatGPT/examples/test_ci.sh | 27 +++ applications/ChatGPT/examples/train_dummy.py | 121 +++++++++++ applications/ChatGPT/examples/train_dummy.sh | 18 ++ .../ChatGPT/examples/train_prompts.py | 113 ++++++++++ .../ChatGPT/examples/train_prompts.sh | 18 ++ .../ChatGPT/examples/train_reward_model.py | 53 +++++ applications/ChatGPT/examples/train_rm.sh | 18 ++ applications/ChatGPT/pytest.ini | 6 + .../requirements/requirements-test.txt | 1 + .../ChatGPT/requirements/requirements.txt | 6 + applications/ChatGPT/setup.py | 42 ++++ applications/ChatGPT/tests/__init__.py | 0 applications/ChatGPT/tests/test_data.py | 117 ++++++++++ applications/ChatGPT/version.txt | 1 + 64 files changed, 3756 insertions(+) create mode 100644 applications/ChatGPT/.gitignore create mode 100644 applications/ChatGPT/LICENSE create mode 100644 applications/ChatGPT/README.md create mode 100644 applications/ChatGPT/benchmarks/README.md create mode 100644 applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py create mode 100755 applications/ChatGPT/benchmarks/benchmark_gpt_dummy.sh create mode 100644 applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py create mode 100644 applications/ChatGPT/chatgpt/__init__.py create mode 100644 applications/ChatGPT/chatgpt/dataset/__init__.py create mode 100644 applications/ChatGPT/chatgpt/dataset/reward_dataset.py create mode 100644 applications/ChatGPT/chatgpt/experience_maker/__init__.py create mode 100644 applications/ChatGPT/chatgpt/experience_maker/base.py create mode 100644 applications/ChatGPT/chatgpt/experience_maker/naive.py create mode 100644 applications/ChatGPT/chatgpt/nn/__init__.py create mode 100644 applications/ChatGPT/chatgpt/nn/actor.py create mode 100644 applications/ChatGPT/chatgpt/nn/bloom_actor.py create mode 100644 applications/ChatGPT/chatgpt/nn/bloom_critic.py create mode 100644 applications/ChatGPT/chatgpt/nn/bloom_rm.py create mode 100644 applications/ChatGPT/chatgpt/nn/critic.py create mode 100644 applications/ChatGPT/chatgpt/nn/generation.py create mode 100644 applications/ChatGPT/chatgpt/nn/generation_utils.py create mode 100644 applications/ChatGPT/chatgpt/nn/gpt_actor.py create mode 100644 applications/ChatGPT/chatgpt/nn/gpt_critic.py create mode 100644 applications/ChatGPT/chatgpt/nn/gpt_rm.py create mode 100644 applications/ChatGPT/chatgpt/nn/lora.py create mode 100644 applications/ChatGPT/chatgpt/nn/loss.py create mode 100644 applications/ChatGPT/chatgpt/nn/opt_actor.py create mode 100644 applications/ChatGPT/chatgpt/nn/opt_critic.py create mode 100644 applications/ChatGPT/chatgpt/nn/opt_rm.py create mode 100644 applications/ChatGPT/chatgpt/nn/reward_model.py create mode 100644 applications/ChatGPT/chatgpt/nn/utils.py create mode 100644 applications/ChatGPT/chatgpt/replay_buffer/__init__.py create mode 100644 applications/ChatGPT/chatgpt/replay_buffer/base.py create mode 100644 applications/ChatGPT/chatgpt/replay_buffer/naive.py create mode 100644 applications/ChatGPT/chatgpt/replay_buffer/utils.py create mode 100644 applications/ChatGPT/chatgpt/trainer/__init__.py create mode 100644 applications/ChatGPT/chatgpt/trainer/base.py create mode 100644 applications/ChatGPT/chatgpt/trainer/callbacks/__init__.py create mode 100644 applications/ChatGPT/chatgpt/trainer/callbacks/base.py create mode 100644 applications/ChatGPT/chatgpt/trainer/callbacks/performance_evaluator.py create mode 100644 applications/ChatGPT/chatgpt/trainer/ppo.py create mode 100644 applications/ChatGPT/chatgpt/trainer/rm.py create mode 100644 applications/ChatGPT/chatgpt/trainer/strategies/__init__.py create mode 100644 applications/ChatGPT/chatgpt/trainer/strategies/base.py create mode 100644 applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py create mode 100644 applications/ChatGPT/chatgpt/trainer/strategies/ddp.py create mode 100644 applications/ChatGPT/chatgpt/trainer/strategies/naive.py create mode 100644 applications/ChatGPT/chatgpt/trainer/utils.py create mode 100644 applications/ChatGPT/examples/README.md create mode 100644 applications/ChatGPT/examples/requirements.txt create mode 100755 applications/ChatGPT/examples/test_ci.sh create mode 100644 applications/ChatGPT/examples/train_dummy.py create mode 100755 applications/ChatGPT/examples/train_dummy.sh create mode 100644 applications/ChatGPT/examples/train_prompts.py create mode 100755 applications/ChatGPT/examples/train_prompts.sh create mode 100644 applications/ChatGPT/examples/train_reward_model.py create mode 100755 applications/ChatGPT/examples/train_rm.sh create mode 100644 applications/ChatGPT/pytest.ini create mode 100644 applications/ChatGPT/requirements/requirements-test.txt create mode 100644 applications/ChatGPT/requirements/requirements.txt create mode 100644 applications/ChatGPT/setup.py create mode 100644 applications/ChatGPT/tests/__init__.py create mode 100644 applications/ChatGPT/tests/test_data.py create mode 100644 applications/ChatGPT/version.txt diff --git a/applications/ChatGPT/.gitignore b/applications/ChatGPT/.gitignore new file mode 100644 index 000000000..40f3f6deb --- /dev/null +++ b/applications/ChatGPT/.gitignore @@ -0,0 +1,146 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/.build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# IDE +.idea/ +.vscode/ + +# macos +*.DS_Store +#data/ + +docs/.build + +# pytorch checkpoint +*.pt + +# ignore version.py generated by setup.py +colossalai/version.py diff --git a/applications/ChatGPT/LICENSE b/applications/ChatGPT/LICENSE new file mode 100644 index 000000000..0528c89ea --- /dev/null +++ b/applications/ChatGPT/LICENSE @@ -0,0 +1,202 @@ +Copyright 2021- HPC-AI Technology Inc. All rights reserved. + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021- HPC-AI Technology Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/applications/ChatGPT/README.md b/applications/ChatGPT/README.md new file mode 100644 index 000000000..dce59ad4b --- /dev/null +++ b/applications/ChatGPT/README.md @@ -0,0 +1,80 @@ +# RLHF - ColossalAI + +Implementation of RLHF (Reinforcement Learning with Human Feedback) powered by ColossalAI. It supports distributed training and offloading, which can fit extremly large models. + +

+ +

+ +## Training process (step 3) +

+ +

+

+ +

+ + +## Install +```shell +pip install . +``` + + +## Usage + +The main entrypoint is `Trainer`. We only support PPO trainer now. We support many training strategies: + +- NaiveStrategy: simplest strategy. Train on single GPU. +- DDPStrategy: use `torch.nn.parallel.DistributedDataParallel`. Train on multi GPUs. +- ColossalAIStrategy: use Gemini and Zero of ColossalAI. It eliminates model duplication on each GPU and supports offload. It's very useful when training large models on multi GPUs. + +Simplest usage: + +```python +from chatgpt.trainer import PPOTrainer +from chatgpt.trainer.strategies import ColossalAIStrategy + +strategy = ColossalAIStrategy() + +with strategy.model_init_context(): + # init your model here + actor = Actor() + critic = Critic() + +trainer = PPOTrainer(actor = actor, critic= critic, strategy, ...) + +trainer.fit(dataset, ...) +``` + +For more details, see `examples/`. + +We also support training reward model with true-world data. See `examples/train_reward_model.py`. + +## Todo + +- [x] implement PPO training +- [x] implement training reward model +- [x] support LoRA +- [ ] implement PPO-ptx fine-tuning +- [ ] integrate with Ray +- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL) + +## Citations + +```bibtex +@article{Hu2021LoRALA, + title = {LoRA: Low-Rank Adaptation of Large Language Models}, + author = {Edward J. Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Weizhu Chen}, + journal = {ArXiv}, + year = {2021}, + volume = {abs/2106.09685} +} + +@article{ouyang2022training, + title={Training language models to follow instructions with human feedback}, + author={Ouyang, Long and Wu, Jeff and Jiang, Xu and Almeida, Diogo and Wainwright, Carroll L and Mishkin, Pamela and Zhang, Chong and Agarwal, Sandhini and Slama, Katarina and Ray, Alex and others}, + journal={arXiv preprint arXiv:2203.02155}, + year={2022} +} +``` diff --git a/applications/ChatGPT/benchmarks/README.md b/applications/ChatGPT/benchmarks/README.md new file mode 100644 index 000000000..f7212fc89 --- /dev/null +++ b/applications/ChatGPT/benchmarks/README.md @@ -0,0 +1,94 @@ +# Benchmarks + +## Benchmark GPT on dummy prompt data + +We provide various GPT models (string in parentheses is the corresponding model name used in this script): + +- GPT2-S (s) +- GPT2-M (m) +- GPT2-L (l) +- GPT2-XL (xl) +- GPT2-4B (4b) +- GPT2-6B (6b) +- GPT2-8B (8b) +- GPT2-10B (10b) +- GPT2-12B (12b) +- GPT2-15B (15b) +- GPT2-18B (18b) +- GPT2-20B (20b) +- GPT2-24B (24b) +- GPT2-28B (28b) +- GPT2-32B (32b) +- GPT2-36B (36b) +- GPT2-40B (40b) +- GPT3 (175b) + +We also provide various training strategies: + +- ddp: torch DDP +- colossalai_gemini: ColossalAI GeminiDDP with `placement_policy="cuda"`, like zero3 +- colossalai_gemini_cpu: ColossalAI GeminiDDP with `placement_policy="cpu"`, like zero3-offload +- colossalai_zero2: ColossalAI zero2 +- colossalai_zero2_cpu: ColossalAI zero2-offload +- colossalai_zero1: ColossalAI zero1 +- colossalai_zero1_cpu: ColossalAI zero1-offload + +We only support `torchrun` to launch now. E.g. + +```shell +# run GPT2-S on single-node single-GPU with min batch size +torchrun --standalone --nproc_pero_node 1 benchmark_gpt_dummy.py --model s --strategy ddp --experience_batch_size 1 --train_batch_size 1 +# run GPT2-XL on single-node 4-GPU +torchrun --standalone --nproc_per_node 4 benchmark_gpt_dummy.py --model xl --strategy colossalai_zero2 +# run GPT3 on 8-node 8-GPU +torchrun --nnodes 8 --nproc_per_node 8 \ + --rdzv_id=$JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$HOST_NODE_ADDR \ + benchmark_gpt_dummy.py --model 175b --strategy colossalai_gemini +``` + +> ⚠ Batch sizes in CLI args and outputed throughput/TFLOPS are all values of per GPU. + +In this benchmark, we assume the model architectures/sizes of actor and critic are the same for simplicity. But in practice, to reduce training cost, we may use a smaller critic. + +We also provide a simple shell script to run a set of benchmarks. But it only supports benchmark on single node. However, it's easy to run on multi-nodes by modifying launch command in this script. + +Usage: + +```shell +# run for GPUS=(1 2 4 8) x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256) +./benchmark_gpt_dummy.sh +# run for GPUS=2 x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256) +./benchmark_gpt_dummy.sh 2 +# run for GPUS=2 x strategy=ddp x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256) +./benchmark_gpt_dummy.sh 2 ddp +# run for GPUS=2 x strategy=ddp x model=l x batch_size=(1 2 4 8 16 32 64 128 256) +./benchmark_gpt_dummy.sh 2 ddp l +``` + +## Benchmark OPT with LoRA on dummy prompt data + +We provide various OPT models (string in parentheses is the corresponding model name used in this script): + +- OPT-125M (125m) +- OPT-350M (350m) +- OPT-700M (700m) +- OPT-1.3B (1.3b) +- OPT-2.7B (2.7b) +- OPT-3.5B (3.5b) +- OPT-5.5B (5.5b) +- OPT-6.7B (6.7b) +- OPT-10B (10b) +- OPT-13B (13b) + +We only support `torchrun` to launch now. E.g. + +```shell +# run OPT-125M with no lora (lora_rank=0) on single-node single-GPU with min batch size +torchrun --standalone --nproc_pero_node 1 benchmark_opt_lora_dummy.py --model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0 +# run OPT-350M with lora_rank=4 on single-node 4-GPU +torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py --model 350m --strategy colossalai_zero2 --lora_rank 4 +``` + +> ⚠ Batch sizes in CLI args and outputed throughput/TFLOPS are all values of per GPU. + +In this benchmark, we assume the model architectures/sizes of actor and critic are the same for simplicity. But in practice, to reduce training cost, we may use a smaller critic. diff --git a/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py new file mode 100644 index 000000000..8474f3ba7 --- /dev/null +++ b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py @@ -0,0 +1,183 @@ +import argparse +from copy import deepcopy + +import torch +import torch.distributed as dist +import torch.nn as nn +from chatgpt.nn import GPTActor, GPTCritic, RewardModel +from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn +from chatgpt.trainer import PPOTrainer +from chatgpt.trainer.callbacks import PerformanceEvaluator +from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy +from torch.optim import Adam +from transformers.models.gpt2.configuration_gpt2 import GPT2Config +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.nn.optimizer import HybridAdam + + +def get_model_numel(model: nn.Module, strategy: Strategy) -> int: + numel = sum(p.numel() for p in model.parameters()) + if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init: + numel *= dist.get_world_size() + return numel + + +def preprocess_batch(samples) -> dict: + input_ids = torch.stack(samples) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + return {'input_ids': input_ids, 'attention_mask': attention_mask} + + +def print_rank_0(*args, **kwargs) -> None: + if dist.get_rank() == 0: + print(*args, **kwargs) + + +def print_model_numel(model_dict: dict) -> None: + B = 1024**3 + M = 1024**2 + K = 1024 + outputs = '' + for name, numel in model_dict.items(): + outputs += f'{name}: ' + if numel >= B: + outputs += f'{numel / B:.2f} B\n' + elif numel >= M: + outputs += f'{numel / M:.2f} M\n' + elif numel >= K: + outputs += f'{numel / K:.2f} K\n' + else: + outputs += f'{numel}\n' + print_rank_0(outputs) + + +def get_gpt_config(model_name: str) -> GPT2Config: + model_map = { + 's': GPT2Config(), + 'm': GPT2Config(n_embd=1024, n_layer=24, n_head=16), + 'l': GPT2Config(n_embd=1280, n_layer=36, n_head=20), + 'xl': GPT2Config(n_embd=1600, n_layer=48, n_head=25), + '2b': GPT2Config(n_embd=2048, n_layer=40, n_head=16), + '4b': GPT2Config(n_embd=2304, n_layer=64, n_head=16), + '6b': GPT2Config(n_embd=4096, n_layer=30, n_head=16), + '8b': GPT2Config(n_embd=4096, n_layer=40, n_head=16), + '10b': GPT2Config(n_embd=4096, n_layer=50, n_head=16), + '12b': GPT2Config(n_embd=4096, n_layer=60, n_head=16), + '15b': GPT2Config(n_embd=4096, n_layer=78, n_head=16), + '18b': GPT2Config(n_embd=4096, n_layer=90, n_head=16), + '20b': GPT2Config(n_embd=8192, n_layer=25, n_head=16), + '24b': GPT2Config(n_embd=8192, n_layer=30, n_head=16), + '28b': GPT2Config(n_embd=8192, n_layer=35, n_head=16), + '32b': GPT2Config(n_embd=8192, n_layer=40, n_head=16), + '36b': GPT2Config(n_embd=8192, n_layer=45, n_head=16), + '40b': GPT2Config(n_embd=8192, n_layer=50, n_head=16), + '175b': GPT2Config(n_positions=2048, n_embd=12288, n_layer=96, n_head=96), + } + try: + return model_map[model_name] + except KeyError: + raise ValueError(f'Unknown model "{model_name}"') + + +def main(args): + if args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) + elif args.strategy == 'colossalai_gemini_cpu': + strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5) + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2_cpu': + strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') + elif args.strategy == 'colossalai_zero1': + strategy = ColossalAIStrategy(stage=1, placement_policy='cuda') + elif args.strategy == 'colossalai_zero1_cpu': + strategy = ColossalAIStrategy(stage=1, placement_policy='cpu') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + model_config = get_gpt_config(args.model) + + with strategy.model_init_context(): + actor = GPTActor(config=model_config).cuda() + critic = GPTCritic(config=model_config).cuda() + + initial_model = deepcopy(actor).cuda() + reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda() + + actor_numel = get_model_numel(actor, strategy) + critic_numel = get_model_numel(critic, strategy) + initial_model_numel = get_model_numel(initial_model, strategy) + reward_model_numel = get_model_numel(reward_model, strategy) + print_model_numel({ + 'Actor': actor_numel, + 'Critic': critic_numel, + 'Initial model': initial_model_numel, + 'Reward model': reward_model_numel + }) + performance_evaluator = PerformanceEvaluator(actor_numel, + critic_numel, + initial_model_numel, + reward_model_numel, + enable_grad_checkpoint=False, + ignore_episodes=1) + + if args.strategy.startswith('colossalai'): + actor_optim = HybridAdam(actor.parameters(), lr=5e-6) + critic_optim = HybridAdam(critic.parameters(), lr=5e-6) + else: + actor_optim = Adam(actor.parameters(), lr=5e-6) + critic_optim = Adam(critic.parameters(), lr=5e-6) + + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + + trainer = PPOTrainer(strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + max_epochs=args.max_epochs, + train_batch_size=args.train_batch_size, + experience_batch_size=args.experience_batch_size, + tokenizer=preprocess_batch, + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + prepare_inputs_fn=gpt_prepare_inputs_fn, + update_model_kwargs_fn=update_model_kwargs_fn, + callbacks=[performance_evaluator]) + + random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device()) + trainer.fit(random_prompts, + num_episodes=args.num_episodes, + max_timesteps=args.max_timesteps, + update_timesteps=args.update_timesteps) + + print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='s') + parser.add_argument('--strategy', + choices=[ + 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2', + 'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu' + ], + default='ddp') + parser.add_argument('--num_episodes', type=int, default=3) + parser.add_argument('--max_timesteps', type=int, default=8) + parser.add_argument('--update_timesteps', type=int, default=8) + parser.add_argument('--max_epochs', type=int, default=3) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + args = parser.parse_args() + main(args) diff --git a/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.sh b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.sh new file mode 100755 index 000000000..d70f88725 --- /dev/null +++ b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +# Usage: $0 +set -xu + +BASE=$(realpath $(dirname $0)) + + +PY_SCRIPT=${BASE}/benchmark_gpt_dummy.py +export OMP_NUM_THREADS=8 + +function tune_batch_size() { + # we found when experience batch size is equal to train batch size + # peak CUDA memory usage of making experience phase is less than or equal to that of training phase + # thus, experience batch size can be larger than or equal to train batch size + for bs in 1 2 4 8 16 32 64 128 256; do + torchrun --standalone --nproc_per_node $1 $PY_SCRIPT --model $2 --strategy $3 --experience_batch_size $bs --train_batch_size $bs || return 1 + done +} + +if [ $# -eq 0 ]; then + num_gpus=(1 2 4 8) +else + num_gpus=($1) +fi + +if [ $# -le 1 ]; then + strategies=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") +else + strategies=($2) +fi + +if [ $# -le 2 ]; then + models=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") +else + models=($3) +fi + + +for num_gpu in ${num_gpus[@]}; do + for strategy in ${strategies[@]}; do + for model in ${models[@]}; do + tune_batch_size $num_gpu $model $strategy || break + done + done +done diff --git a/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py b/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py new file mode 100644 index 000000000..accbc4155 --- /dev/null +++ b/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py @@ -0,0 +1,178 @@ +import argparse +from copy import deepcopy + +import torch +import torch.distributed as dist +import torch.nn as nn +from chatgpt.nn import OPTActor, OPTCritic, RewardModel +from chatgpt.nn.generation_utils import opt_prepare_inputs_fn, update_model_kwargs_fn +from chatgpt.trainer import PPOTrainer +from chatgpt.trainer.callbacks import PerformanceEvaluator +from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy +from torch.optim import Adam +from transformers import AutoTokenizer +from transformers.models.opt.configuration_opt import OPTConfig + +from colossalai.nn.optimizer import HybridAdam + + +def get_model_numel(model: nn.Module, strategy: Strategy) -> int: + numel = sum(p.numel() for p in model.parameters()) + if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init: + numel *= dist.get_world_size() + return numel + + +def preprocess_batch(samples) -> dict: + input_ids = torch.stack(samples) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + return {'input_ids': input_ids, 'attention_mask': attention_mask} + + +def print_rank_0(*args, **kwargs) -> None: + if dist.get_rank() == 0: + print(*args, **kwargs) + + +def print_model_numel(model_dict: dict) -> None: + B = 1024**3 + M = 1024**2 + K = 1024 + outputs = '' + for name, numel in model_dict.items(): + outputs += f'{name}: ' + if numel >= B: + outputs += f'{numel / B:.2f} B\n' + elif numel >= M: + outputs += f'{numel / M:.2f} M\n' + elif numel >= K: + outputs += f'{numel / K:.2f} K\n' + else: + outputs += f'{numel}\n' + print_rank_0(outputs) + + +def get_gpt_config(model_name: str) -> OPTConfig: + model_map = { + '125m': OPTConfig.from_pretrained('facebook/opt-125m'), + '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16), + '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20), + '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'), + '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'), + '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32), + '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32), + '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'), + '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32), + '13b': OPTConfig.from_pretrained('facebook/opt-13b'), + } + try: + return model_map[model_name] + except KeyError: + raise ValueError(f'Unknown model "{model_name}"') + + +def main(args): + if args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) + elif args.strategy == 'colossalai_gemini_cpu': + strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5) + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2_cpu': + strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') + elif args.strategy == 'colossalai_zero1': + strategy = ColossalAIStrategy(stage=1, placement_policy='cuda') + elif args.strategy == 'colossalai_zero1_cpu': + strategy = ColossalAIStrategy(stage=1, placement_policy='cpu') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac) + + model_config = get_gpt_config(args.model) + + with strategy.model_init_context(): + actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda() + critic = OPTCritic(config=model_config, lora_rank=args.lora_rank).cuda() + + initial_model = deepcopy(actor).cuda() + reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda() + + actor_numel = get_model_numel(actor, strategy) + critic_numel = get_model_numel(critic, strategy) + initial_model_numel = get_model_numel(initial_model, strategy) + reward_model_numel = get_model_numel(reward_model, strategy) + print_model_numel({ + 'Actor': actor_numel, + 'Critic': critic_numel, + 'Initial model': initial_model_numel, + 'Reward model': reward_model_numel + }) + performance_evaluator = PerformanceEvaluator(actor_numel, + critic_numel, + initial_model_numel, + reward_model_numel, + enable_grad_checkpoint=False, + ignore_episodes=1) + + if args.strategy.startswith('colossalai'): + actor_optim = HybridAdam(actor.parameters(), lr=5e-6) + critic_optim = HybridAdam(critic.parameters(), lr=5e-6) + else: + actor_optim = Adam(actor.parameters(), lr=5e-6) + critic_optim = Adam(critic.parameters(), lr=5e-6) + + tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') + tokenizer.pad_token = tokenizer.eos_token + + trainer = PPOTrainer(strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + max_epochs=args.max_epochs, + train_batch_size=args.train_batch_size, + experience_batch_size=args.experience_batch_size, + tokenizer=preprocess_batch, + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + prepare_inputs_fn=opt_prepare_inputs_fn, + update_model_kwargs_fn=update_model_kwargs_fn, + callbacks=[performance_evaluator]) + + random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device()) + trainer.fit(random_prompts, + num_episodes=args.num_episodes, + max_timesteps=args.max_timesteps, + update_timesteps=args.update_timesteps) + + print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='125m') + parser.add_argument('--strategy', + choices=[ + 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2', + 'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu' + ], + default='ddp') + parser.add_argument('--num_episodes', type=int, default=3) + parser.add_argument('--max_timesteps', type=int, default=8) + parser.add_argument('--update_timesteps', type=int, default=8) + parser.add_argument('--max_epochs', type=int, default=3) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=4) + parser.add_argument('--cuda_mem_frac', type=float, default=1.0) + args = parser.parse_args() + main(args) diff --git a/applications/ChatGPT/chatgpt/__init__.py b/applications/ChatGPT/chatgpt/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/ChatGPT/chatgpt/dataset/__init__.py b/applications/ChatGPT/chatgpt/dataset/__init__.py new file mode 100644 index 000000000..2f330ee67 --- /dev/null +++ b/applications/ChatGPT/chatgpt/dataset/__init__.py @@ -0,0 +1,3 @@ +from .reward_dataset import RewardDataset + +__all__ = ['RewardDataset'] diff --git a/applications/ChatGPT/chatgpt/dataset/reward_dataset.py b/applications/ChatGPT/chatgpt/dataset/reward_dataset.py new file mode 100644 index 000000000..14edcce30 --- /dev/null +++ b/applications/ChatGPT/chatgpt/dataset/reward_dataset.py @@ -0,0 +1,52 @@ +from typing import Callable + +from torch.utils.data import Dataset +from tqdm import tqdm + + +class RewardDataset(Dataset): + """ + Dataset for reward model + + Args: + dataset: dataset for reward model + tokenizer: tokenizer for reward model + max_length: max length of input + """ + + def __init__(self, dataset, tokenizer: Callable, max_length: int) -> None: + super().__init__() + self.chosen = [] + self.reject = [] + for data in tqdm(dataset): + prompt = data['prompt'] + + chosen = prompt + data['chosen'] + "<|endoftext|>" + chosen_token = tokenizer(chosen, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.chosen.append({ + "input_ids": chosen_token['input_ids'], + "attention_mask": chosen_token['attention_mask'] + }) + + reject = prompt + data['rejected'] + "<|endoftext|>" + reject_token = tokenizer(reject, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.reject.append({ + "input_ids": reject_token['input_ids'], + "attention_mask": reject_token['attention_mask'] + }) + + def __len__(self): + length = len(self.chosen) + return length + + def __getitem__(self, idx): + return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ + "input_ids"], self.reject[idx]["attention_mask"] diff --git a/applications/ChatGPT/chatgpt/experience_maker/__init__.py b/applications/ChatGPT/chatgpt/experience_maker/__init__.py new file mode 100644 index 000000000..39ca7576b --- /dev/null +++ b/applications/ChatGPT/chatgpt/experience_maker/__init__.py @@ -0,0 +1,4 @@ +from .base import Experience, ExperienceMaker +from .naive import NaiveExperienceMaker + +__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker'] diff --git a/applications/ChatGPT/chatgpt/experience_maker/base.py b/applications/ChatGPT/chatgpt/experience_maker/base.py new file mode 100644 index 000000000..61895322c --- /dev/null +++ b/applications/ChatGPT/chatgpt/experience_maker/base.py @@ -0,0 +1,77 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from chatgpt.nn.actor import Actor + + +@dataclass +class Experience: + """Experience is a batch of data. + These data should have the the sequence length and number of actions. + Left padding for sequences is applied. + + Shapes of each tensor: + sequences: (B, S) + action_log_probs: (B, A) + values: (B) + reward: (B) + advatanges: (B) + attention_mask: (B, S) + action_mask: (B, A) + + "A" is the number of actions. + """ + sequences: torch.Tensor + action_log_probs: torch.Tensor + values: torch.Tensor + reward: torch.Tensor + advantages: torch.Tensor + attention_mask: Optional[torch.LongTensor] + action_mask: Optional[torch.BoolTensor] + + @torch.no_grad() + def to_device(self, device: torch.device) -> None: + self.sequences = self.sequences.to(device) + self.action_log_probs = self.action_log_probs.to(device) + self.values = self.values.to(device) + self.reward = self.reward.to(device) + self.advantages = self.advantages.to(device) + if self.attention_mask is not None: + self.attention_mask = self.attention_mask.to(device) + if self.action_mask is not None: + self.action_mask = self.action_mask.to(device) + + def pin_memory(self): + self.sequences = self.sequences.pin_memory() + self.action_log_probs = self.action_log_probs.pin_memory() + self.values = self.values.pin_memory() + self.reward = self.reward.pin_memory() + self.advantages = self.advantages.pin_memory() + if self.attention_mask is not None: + self.attention_mask = self.attention_mask.pin_memory() + if self.action_mask is not None: + self.action_mask = self.action_mask.pin_memory() + return self + + +class ExperienceMaker(ABC): + + def __init__(self, + actor: Actor, + critic: nn.Module, + reward_model: nn.Module, + initial_model: Actor, + kl_coef: float = 0.1) -> None: + super().__init__() + self.actor = actor + self.critic = critic + self.reward_model = reward_model + self.initial_model = initial_model + self.kl_coef = kl_coef + + @abstractmethod + def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: + pass diff --git a/applications/ChatGPT/chatgpt/experience_maker/naive.py b/applications/ChatGPT/chatgpt/experience_maker/naive.py new file mode 100644 index 000000000..f4fd2078c --- /dev/null +++ b/applications/ChatGPT/chatgpt/experience_maker/naive.py @@ -0,0 +1,36 @@ +import torch +from chatgpt.nn.utils import compute_reward, normalize + +from .base import Experience, ExperienceMaker + + +class NaiveExperienceMaker(ExperienceMaker): + """ + Naive experience maker. + """ + + @torch.no_grad() + def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: + self.actor.eval() + self.critic.eval() + self.initial_model.eval() + self.reward_model.eval() + + sequences, attention_mask, action_mask = self.actor.generate(input_ids, + return_action_mask=True, + **generate_kwargs) + num_actions = action_mask.size(1) + + action_log_probs = self.actor(sequences, num_actions, attention_mask) + base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask) + value = self.critic(sequences, action_mask, attention_mask) + r = self.reward_model(sequences, attention_mask) + + reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask) + + advantage = reward - value + # TODO(ver217): maybe normalize adv + if advantage.ndim == 1: + advantage = advantage.unsqueeze(-1) + + return Experience(sequences, action_log_probs, value, reward, advantage, attention_mask, action_mask) diff --git a/applications/ChatGPT/chatgpt/nn/__init__.py b/applications/ChatGPT/chatgpt/nn/__init__.py new file mode 100644 index 000000000..c728d7df3 --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/__init__.py @@ -0,0 +1,18 @@ +from .actor import Actor +from .bloom_actor import BLOOMActor +from .bloom_critic import BLOOMCritic +from .bloom_rm import BLOOMRM +from .critic import Critic +from .gpt_actor import GPTActor +from .gpt_critic import GPTCritic +from .gpt_rm import GPTRM +from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss +from .opt_actor import OPTActor +from .opt_critic import OPTCritic +from .opt_rm import OPTRM +from .reward_model import RewardModel + +__all__ = [ + 'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss', 'GPTActor', + 'GPTCritic', 'GPTRM', 'BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'OPTActor', 'OPTCritic', 'OPTRM' +] diff --git a/applications/ChatGPT/chatgpt/nn/actor.py b/applications/ChatGPT/chatgpt/nn/actor.py new file mode 100644 index 000000000..c4c0d579d --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/actor.py @@ -0,0 +1,62 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .generation import generate +from .lora import LoRAModule +from .utils import log_probs_from_logits + + +class Actor(LoRAModule): + """ + Actor model base class. + + Args: + model (nn.Module): Actor Model. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: + super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) + self.model = model + self.convert_to_lora() + + @torch.no_grad() + def generate( + self, + input_ids: torch.Tensor, + return_action_mask: bool = True, + **kwargs + ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: + sequences = generate(self.model, input_ids, **kwargs) + attention_mask = None + pad_token_id = kwargs.get('pad_token_id', None) + if pad_token_id is not None: + attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) + if not return_action_mask: + return sequences, attention_mask + input_len = input_ids.size(1) + eos_token_id = kwargs.get('eos_token_id', None) + if eos_token_id is None: + action_mask = torch.ones_like(sequences, dtype=torch.bool) + else: + # left padding may be applied, only mask action + action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + action_mask[:, :input_len] = False + action_mask = action_mask[:, 1:] + return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):] + + def forward(self, + sequences: torch.LongTensor, + num_actions: int, + attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Returns action log probs + """ + output = self.model(sequences, attention_mask=attention_mask) + logits = output['logits'] + log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + return log_probs[:, -num_actions:] diff --git a/applications/ChatGPT/chatgpt/nn/bloom_actor.py b/applications/ChatGPT/chatgpt/nn/bloom_actor.py new file mode 100644 index 000000000..103536bc3 --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/bloom_actor.py @@ -0,0 +1,35 @@ +from typing import Optional + +import torch +from transformers import BloomConfig, BloomForCausalLM, BloomModel + +from .actor import Actor + + +class BLOOMActor(Actor): + """ + BLOOM Actor model. + + Args: + pretrained (str): Pretrained model name or path. + config (BloomConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = BloomForCausalLM.from_pretrained(pretrained) + elif config is not None: + model = BloomForCausalLM(config) + else: + model = BloomForCausalLM(BloomConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + super().__init__(model, lora_rank, lora_train_bias) diff --git a/applications/ChatGPT/chatgpt/nn/bloom_critic.py b/applications/ChatGPT/chatgpt/nn/bloom_critic.py new file mode 100644 index 000000000..3b03471a3 --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/bloom_critic.py @@ -0,0 +1,37 @@ +from typing import Optional + +import torch +import torch.nn as nn +from transformers import BloomConfig, BloomForCausalLM, BloomModel + +from .critic import Critic + + +class BLOOMCritic(Critic): + """ + BLOOM Critic model. + + Args: + pretrained (str): Pretrained model name or path. + config (BloomConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = BloomModel.from_pretrained(pretrained) + elif config is not None: + model = BloomModel(config) + else: + model = BloomModel(BloomConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) + super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/ChatGPT/chatgpt/nn/bloom_rm.py b/applications/ChatGPT/chatgpt/nn/bloom_rm.py new file mode 100644 index 000000000..0d4dd43fa --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/bloom_rm.py @@ -0,0 +1,37 @@ +from typing import Optional + +import torch +import torch.nn as nn +from transformers import BloomConfig, BloomForCausalLM, BloomModel + +from .reward_model import RewardModel + + +class BLOOMRM(RewardModel): + """ + BLOOM Reward model. + + Args: + pretrained (str): Pretrained model name or path. + config (BloomConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = BloomModel.from_pretrained(pretrained) + elif config is not None: + model = BloomModel(config) + else: + model = BloomModel(BloomConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) + super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/ChatGPT/chatgpt/nn/critic.py b/applications/ChatGPT/chatgpt/nn/critic.py new file mode 100644 index 000000000..f3a123854 --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/critic.py @@ -0,0 +1,47 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from .lora import LoRAModule +from .utils import masked_mean + + +class Critic(LoRAModule): + """ + Critic model base class. + + Args: + model (nn.Module): Critic model. + value_head (nn.Module): Value head to get value. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + model: nn.Module, + value_head: nn.Module, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + + super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) + self.model = model + self.value_head = value_head + self.convert_to_lora() + + def forward(self, + sequences: torch.LongTensor, + action_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + outputs = self.model(sequences, attention_mask=attention_mask) + last_hidden_states = outputs['last_hidden_state'] + + values = self.value_head(last_hidden_states).squeeze(-1)[:, :-1] + + if action_mask is not None: + num_actions = action_mask.size(1) + values = values[:, -num_actions:] + value = masked_mean(values, action_mask, dim=1) + return value + value = values.mean(dim=1).squeeze(1) + return value diff --git a/applications/ChatGPT/chatgpt/nn/generation.py b/applications/ChatGPT/chatgpt/nn/generation.py new file mode 100644 index 000000000..4ee797561 --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/generation.py @@ -0,0 +1,137 @@ +from typing import Any, Callable, Optional + +import torch +import torch.nn as nn + +try: + from transformers.generation_logits_process import ( + LogitsProcessorList, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + ) +except ImportError: + from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper + + +def prepare_logits_processor(top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + if temperature is not None and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if top_k is not None and top_k != 0: + processor_list.append(TopKLogitsWarper(top_k)) + if top_p is not None and top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + return processor_list + + +def sample(model: nn.Module, + input_ids: torch.Tensor, + max_length: int, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs) -> torch.Tensor: + if input_ids.size(1) >= max_length: + return input_ids + + logits_processor = prepare_logits_processor(top_k, top_p, temperature) + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + for _ in range(input_ids.size(1), max_length): + model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else { + 'input_ids': input_ids + } + outputs = model(**model_inputs) + + next_token_logits = outputs['logits'][:, -1, :] + # pre-process distribution + next_token_logits = logits_processor(input_ids, next_token_logits) + # sample + probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if update_model_kwargs_fn is not None: + model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + # stop when each sentence is finished if early_stopping=True + if early_stopping and unfinished_sequences.max() == 0: + break + + return input_ids + + +def generate(model: nn.Module, + input_ids: torch.Tensor, + max_length: int, + num_beams: int = 1, + do_sample: bool = True, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs) -> torch.Tensor: + """Generate token sequence. The returned sequence is input_ids + generated_tokens. + + Args: + model (nn.Module): model + input_ids (torch.Tensor): input sequence + max_length (int): max length of the returned sequence + num_beams (int, optional): number of beams. Defaults to 1. + do_sample (bool, optional): whether to do sample. Defaults to True. + early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False. + eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None. + pad_token_id (Optional[int], optional): pad token id. Defaults to None. + top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None. + top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None. + temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None. + prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None. + update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None. + """ + is_greedy_gen_mode = ((num_beams == 1) and do_sample is False) + is_sample_gen_mode = ((num_beams == 1) and do_sample is True) + is_beam_gen_mode = ((num_beams > 1) and do_sample is False) + if is_greedy_gen_mode: + # run greedy search + raise NotImplementedError + elif is_sample_gen_mode: + # run sample + return sample(model, + input_ids, + max_length, + early_stopping=early_stopping, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + top_k=top_k, + top_p=top_p, + temperature=temperature, + prepare_inputs_fn=prepare_inputs_fn, + update_model_kwargs_fn=update_model_kwargs_fn, + **model_kwargs) + elif is_beam_gen_mode: + raise NotImplementedError + else: + raise ValueError("Unsupported generation mode") diff --git a/applications/ChatGPT/chatgpt/nn/generation_utils.py b/applications/ChatGPT/chatgpt/nn/generation_utils.py new file mode 100644 index 000000000..c7bc1b383 --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/generation_utils.py @@ -0,0 +1,92 @@ +from typing import Optional + +import torch + + +def gpt_prepare_inputs_fn(input_ids: torch.Tensor, past: Optional[torch.Tensor] = None, **kwargs) -> dict: + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + +def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict: + if "past_key_values" in outputs: + model_kwargs["past"] = outputs["past_key_values"] + else: + model_kwargs["past"] = None + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) + + return model_kwargs + + +def opt_prepare_inputs_fn(input_ids: torch.Tensor, + past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + **kwargs) -> dict: + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + +def bloom_prepare_inputs_fn(input_ids: torch.Tensor, + past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + **kwargs) -> dict: + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } diff --git a/applications/ChatGPT/chatgpt/nn/gpt_actor.py b/applications/ChatGPT/chatgpt/nn/gpt_actor.py new file mode 100644 index 000000000..491182ffa --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/gpt_actor.py @@ -0,0 +1,31 @@ +from typing import Optional + +from transformers.models.gpt2.configuration_gpt2 import GPT2Config +from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + +from .actor import Actor + + +class GPTActor(Actor): + """ + GPT Actor model. + + Args: + pretrained (str): Pretrained model name or path. + config (GPT2Config): Model config. + checkpoint (bool): Enable gradient checkpointing. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + checkpoint: bool = False) -> None: + if pretrained is not None: + model = GPT2LMHeadModel.from_pretrained(pretrained) + elif config is not None: + model = GPT2LMHeadModel(config) + else: + model = GPT2LMHeadModel(GPT2Config()) + if checkpoint: + model.gradient_checkpointing_enable() + super().__init__(model) diff --git a/applications/ChatGPT/chatgpt/nn/gpt_critic.py b/applications/ChatGPT/chatgpt/nn/gpt_critic.py new file mode 100644 index 000000000..b0a001f4a --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/gpt_critic.py @@ -0,0 +1,33 @@ +from typing import Optional + +import torch.nn as nn +from transformers.models.gpt2.configuration_gpt2 import GPT2Config +from transformers.models.gpt2.modeling_gpt2 import GPT2Model + +from .critic import Critic + + +class GPTCritic(Critic): + """ + GPT Critic model. + + Args: + pretrained (str): Pretrained model name or path. + config (GPT2Config): Model config. + checkpoint (bool): Enable gradient checkpointing. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + checkpoint: bool = False) -> None: + if pretrained is not None: + model = GPT2Model.from_pretrained(pretrained) + elif config is not None: + model = GPT2Model(config) + else: + model = GPT2Model(GPT2Config()) + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.n_embd, 1) + super().__init__(model, value_head) diff --git a/applications/ChatGPT/chatgpt/nn/gpt_rm.py b/applications/ChatGPT/chatgpt/nn/gpt_rm.py new file mode 100644 index 000000000..c6c41a45a --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/gpt_rm.py @@ -0,0 +1,33 @@ +from typing import Optional + +import torch.nn as nn +from transformers.models.gpt2.configuration_gpt2 import GPT2Config +from transformers.models.gpt2.modeling_gpt2 import GPT2Model + +from .reward_model import RewardModel + + +class GPTRM(RewardModel): + """ + GPT Reward model. + + Args: + pretrained (str): Pretrained model name or path. + config (GPT2Config): Model config. + checkpoint (bool): Enable gradient checkpointing. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + checkpoint: bool = False) -> None: + if pretrained is not None: + model = GPT2Model.from_pretrained(pretrained) + elif config is not None: + model = GPT2Model(config) + else: + model = GPT2Model(GPT2Config()) + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.n_embd, 1) + super().__init__(model, value_head) diff --git a/applications/ChatGPT/chatgpt/nn/lora.py b/applications/ChatGPT/chatgpt/nn/lora.py new file mode 100644 index 000000000..46a43ec91 --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/lora.py @@ -0,0 +1,127 @@ +import math +from typing import Optional + +import loralib as lora +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LoraLinear(lora.LoRALayer, nn.Module): + """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear. + """ + + def __init__( + self, + weight: nn.Parameter, + bias: Optional[nn.Parameter], + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + merge_weights: bool = True, + ): + nn.Module.__init__(self) + lora.LoRALayer.__init__(self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=merge_weights) + self.weight = weight + self.bias = bias + + out_features, in_features = weight.shape + self.in_features = in_features + self.out_features = out_features + + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.T + + def reset_parameters(self): + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def train(self, mode: bool = True): + + def T(w): + return w.T if self.fan_in_fan_out else w + + nn.Module.train(self, mode) + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0: + self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling + self.merged = False + + def eval(self): + + def T(w): + return w.T if self.fan_in_fan_out else w + + nn.Module.eval(self) + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0: + self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor): + + def T(w): + return w.T if self.fan_in_fan_out else w + + if self.r > 0 and not self.merged: + result = F.linear(x, T(self.weight), bias=self.bias) + if self.r > 0: + result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling + return result + else: + return F.linear(x, T(self.weight), bias=self.bias) + + +def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: + assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})' + lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False) + return lora_linear + + +def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + setattr(module, name, lora_linear_wrapper(child, lora_rank)) + else: + convert_to_lora_recursively(child, lora_rank) + + +class LoRAModule(nn.Module): + """A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`. + This calss will convert all torch.nn.Linear layer to LoraLinear layer. + + Args: + lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0. + lora_train_bias (str, optional): Whether LoRA train biases. + 'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers. + Defaults to 'none'. + """ + + def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: + super().__init__() + self.lora_rank = lora_rank + self.lora_train_bias = lora_train_bias + + def convert_to_lora(self) -> None: + if self.lora_rank <= 0: + return + convert_to_lora_recursively(self, self.lora_rank) + lora.mark_only_lora_as_trainable(self, self.lora_train_bias) diff --git a/applications/ChatGPT/chatgpt/nn/loss.py b/applications/ChatGPT/chatgpt/nn/loss.py new file mode 100644 index 000000000..0ebcfea06 --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/loss.py @@ -0,0 +1,105 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from .utils import masked_mean + + +class GPTLMLoss(nn.Module): + """ + GPT Language Model Loss + """ + + def __init__(self): + super().__init__() + self.loss = nn.CrossEntropyLoss() + + def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +class PolicyLoss(nn.Module): + """ + Policy Loss for PPO + """ + + def __init__(self, clip_eps: float = 0.2) -> None: + super().__init__() + self.clip_eps = clip_eps + + def forward(self, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + ratio = (log_probs - old_log_probs).exp() + surr1 = ratio * advantages + surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages + loss = -torch.min(surr1, surr2) + if action_mask is not None: + loss = masked_mean(loss, action_mask) + loss = loss.mean() + return loss + + +class ValueLoss(nn.Module): + """ + Value Loss for PPO + """ + + def __init__(self, clip_eps: float = 0.4) -> None: + super().__init__() + self.clip_eps = clip_eps + + def forward(self, + values: torch.Tensor, + old_values: torch.Tensor, + reward: torch.Tensor, + action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps) + surr1 = (values_clipped - reward)**2 + surr2 = (values - reward)**2 + loss = torch.max(surr1, surr2) + loss = loss.mean() + return loss + + +class PPOPtxActorLoss(nn.Module): + """ + To Do: + + PPO-ptx Actor Loss + """ + + def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None: + super().__init__() + self.pretrain_coef = pretrain_coef + self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps) + self.pretrain_loss_fn = pretrain_loss_fn + + def forward(self, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + lm_logits: torch.Tensor, + lm_input_ids: torch.Tensor, + action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask) + lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids) + return policy_loss + self.pretrain_coef * lm_loss + + +class PairWiseLoss(nn.Module): + """ + Pairwise Loss for Reward Model + """ + + def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor: + probs = torch.sigmoid(chosen_reward - reject_reward) + log_probs = torch.log(probs) + loss = -log_probs.mean() + return loss diff --git a/applications/ChatGPT/chatgpt/nn/opt_actor.py b/applications/ChatGPT/chatgpt/nn/opt_actor.py new file mode 100644 index 000000000..ff2bf7c00 --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/opt_actor.py @@ -0,0 +1,35 @@ +from typing import Optional + +from transformers.models.opt.configuration_opt import OPTConfig +from transformers.models.opt.modeling_opt import OPTForCausalLM + +from .actor import Actor + + +class OPTActor(Actor): + """ + OPT Actor model. + + Args: + pretrained (str): Pretrained model name or path. + config (OPTConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the low-rank approximation. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = OPTForCausalLM.from_pretrained(pretrained) + elif config is not None: + model = OPTForCausalLM(config) + else: + model = OPTForCausalLM(OPTConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + super().__init__(model, lora_rank, lora_train_bias) diff --git a/applications/ChatGPT/chatgpt/nn/opt_critic.py b/applications/ChatGPT/chatgpt/nn/opt_critic.py new file mode 100644 index 000000000..9c9cb873f --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/opt_critic.py @@ -0,0 +1,37 @@ +from typing import Optional + +import torch.nn as nn +from transformers.models.opt.configuration_opt import OPTConfig +from transformers.models.opt.modeling_opt import OPTModel + +from .critic import Critic + + +class OPTCritic(Critic): + """ + OPT Critic model. + + Args: + pretrained (str): Pretrained model name or path. + config (OPTConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the low-rank approximation. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = OPTModel.from_pretrained(pretrained) + elif config is not None: + model = OPTModel(config) + else: + model = OPTModel(OPTConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) + super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/ChatGPT/chatgpt/nn/opt_rm.py b/applications/ChatGPT/chatgpt/nn/opt_rm.py new file mode 100644 index 000000000..150f832e0 --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/opt_rm.py @@ -0,0 +1,33 @@ +from typing import Optional + +import torch.nn as nn +from transformers.models.opt.configuration_opt import OPTConfig +from transformers.models.opt.modeling_opt import OPTModel + +from .reward_model import RewardModel + + +class OPTRM(RewardModel): + """ + OPT Reward model. + + Args: + pretrained (str): Pretrained model name or path. + config (OPTConfig): Model config. + lora_rank (int): Rank of the low-rank approximation. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = OPTModel.from_pretrained(pretrained) + elif config is not None: + model = OPTModel(config) + else: + model = OPTModel(OPTConfig()) + value_head = nn.Linear(model.config.hidden_size, 1) + super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/ChatGPT/chatgpt/nn/reward_model.py b/applications/ChatGPT/chatgpt/nn/reward_model.py new file mode 100644 index 000000000..5108f61a6 --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/reward_model.py @@ -0,0 +1,41 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from .lora import LoRAModule + + +class RewardModel(LoRAModule): + """ + Reward model base class. + + Args: + model (nn.Module): Reward model. + value_head (nn.Module): Value head to get reward score. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + model: nn.Module, + value_head: Optional[nn.Module] = None, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) + self.model = model + if value_head is not None: + if value_head.out_features != 1: + raise ValueError("The value head of reward model's output dim should be 1!") + self.value_head = value_head + + else: + self.value_head = nn.Linear(model.config.n_embd, 1) + self.convert_to_lora() + + def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + outputs = self.model(sequences, attention_mask=attention_mask) + last_hidden_states = outputs['last_hidden_state'] + values = self.value_head(last_hidden_states)[:, :-1] + value = values.mean(dim=1).squeeze(1) # ensure shape is (B) + return value diff --git a/applications/ChatGPT/chatgpt/nn/utils.py b/applications/ChatGPT/chatgpt/nn/utils.py new file mode 100644 index 000000000..0ff13181f --- /dev/null +++ b/applications/ChatGPT/chatgpt/nn/utils.py @@ -0,0 +1,92 @@ +from typing import Optional, Union + +import loralib as lora +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def compute_approx_kl(log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Compute the approximate KL divergence between two distributions. + Schulman blog: http://joschu.net/blog/kl-approx.html + + Args: + log_probs: Log probabilities of the new distribution. + log_probs_base: Log probabilities of the base distribution. + action_mask: Mask for actions. + """ + + log_ratio = log_probs - log_probs_base + approx_kl = (log_ratio.exp() - 1) - log_ratio + if action_mask is not None: + approx_kl = masked_mean(approx_kl, action_mask, dim=1) + return approx_kl + approx_kl = approx_kl.mean(dim=1) + return approx_kl + + +def compute_reward(r: Union[torch.Tensor, float], + kl_coef: float, + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if kl_coef <= 0.0: + return r + kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask) + reward = r - kl_coef * kl + return reward + + +def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + log_probs = F.log_softmax(logits, dim=-1) + log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) + return log_probs_labels.squeeze(-1) + + +def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: + tensor = tensor * mask + tensor = tensor.sum(dim=dim) + mask_sum = mask.sum(dim=dim) + mean = tensor / (mask_sum + 1e-8) + return mean + + +def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor: + tensor = tensor * mask + mean = masked_mean(tensor, mask, dim=dim) + mean_centered = tensor - mean + var = masked_mean(mean_centered**2, mask, dim=dim) + return mean_centered * var.clamp(min=eps).rsqrt() + + +def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor: + mean = tensor.mean(dim) + mean_centered = tensor - mean + var = (mean_centered**2).mean(dim) + norm = mean_centered * var.clamp(min=eps).rsqrt() + return norm + + +def convert_to_lora(model: nn.Module, + input_size: int, + output_size: int, + lora_rank: int = 16, + lora_alpha: int = 1, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, + merge_weights: bool = True): + if lora_rank > min(input_size, output_size): + raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}") + + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + module._modules[name] = lora.Linear(input_size, + output_size, + r=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + fan_in_fan_out=fan_in_fan_out, + merge_weights=merge_weights) diff --git a/applications/ChatGPT/chatgpt/replay_buffer/__init__.py b/applications/ChatGPT/chatgpt/replay_buffer/__init__.py new file mode 100644 index 000000000..1ebf60382 --- /dev/null +++ b/applications/ChatGPT/chatgpt/replay_buffer/__init__.py @@ -0,0 +1,4 @@ +from .base import ReplayBuffer +from .naive import NaiveReplayBuffer + +__all__ = ['ReplayBuffer', 'NaiveReplayBuffer'] diff --git a/applications/ChatGPT/chatgpt/replay_buffer/base.py b/applications/ChatGPT/chatgpt/replay_buffer/base.py new file mode 100644 index 000000000..5036b0904 --- /dev/null +++ b/applications/ChatGPT/chatgpt/replay_buffer/base.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from typing import Any + +from chatgpt.experience_maker.base import Experience + + +class ReplayBuffer(ABC): + """Replay buffer base class. It stores experience. + + Args: + sample_batch_size (int): Batch size when sampling. + limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. + """ + + def __init__(self, sample_batch_size: int, limit: int = 0) -> None: + super().__init__() + self.sample_batch_size = sample_batch_size + # limit <= 0 means unlimited + self.limit = limit + + @abstractmethod + def append(self, experience: Experience) -> None: + pass + + @abstractmethod + def clear(self) -> None: + pass + + @abstractmethod + def sample(self) -> Experience: + pass + + @abstractmethod + def __len__(self) -> int: + pass + + @abstractmethod + def __getitem__(self, idx: int) -> Any: + pass + + @abstractmethod + def collate_fn(self, batch: Any) -> Experience: + pass diff --git a/applications/ChatGPT/chatgpt/replay_buffer/naive.py b/applications/ChatGPT/chatgpt/replay_buffer/naive.py new file mode 100644 index 000000000..3fc53da65 --- /dev/null +++ b/applications/ChatGPT/chatgpt/replay_buffer/naive.py @@ -0,0 +1,57 @@ +import random +from typing import List + +import torch +from chatgpt.experience_maker.base import Experience + +from .base import ReplayBuffer +from .utils import BufferItem, make_experience_batch, split_experience_batch + + +class NaiveReplayBuffer(ReplayBuffer): + """Naive replay buffer class. It stores experience. + + Args: + sample_batch_size (int): Batch size when sampling. + limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. + cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True. + """ + + def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None: + super().__init__(sample_batch_size, limit) + self.cpu_offload = cpu_offload + self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}') + # TODO(ver217): add prefetch + self.items: List[BufferItem] = [] + + @torch.no_grad() + def append(self, experience: Experience) -> None: + if self.cpu_offload: + experience.to_device(torch.device('cpu')) + items = split_experience_batch(experience) + self.items.extend(items) + if self.limit > 0: + samples_to_remove = len(self.items) - self.limit + if samples_to_remove > 0: + self.items = self.items[samples_to_remove:] + + def clear(self) -> None: + self.items.clear() + + @torch.no_grad() + def sample(self) -> Experience: + items = random.sample(self.items, self.sample_batch_size) + experience = make_experience_batch(items) + if self.cpu_offload: + experience.to_device(self.target_device) + return experience + + def __len__(self) -> int: + return len(self.items) + + def __getitem__(self, idx: int) -> BufferItem: + return self.items[idx] + + def collate_fn(self, batch) -> Experience: + experience = make_experience_batch(batch) + return experience diff --git a/applications/ChatGPT/chatgpt/replay_buffer/utils.py b/applications/ChatGPT/chatgpt/replay_buffer/utils.py new file mode 100644 index 000000000..752f16704 --- /dev/null +++ b/applications/ChatGPT/chatgpt/replay_buffer/utils.py @@ -0,0 +1,73 @@ +from dataclasses import dataclass +from typing import List, Optional + +import torch +import torch.nn.functional as F +from chatgpt.experience_maker.base import Experience + + +@dataclass +class BufferItem: + """BufferItem is an item of experience data. + + Shapes of each tensor: + sequences: (S) + action_log_probs: (A) + values: (1) + reward: (1) + advatanges: (1) + attention_mask: (S) + action_mask: (A) + + "A" is the number of actions. + """ + sequences: torch.Tensor + action_log_probs: torch.Tensor + values: torch.Tensor + reward: torch.Tensor + advantages: torch.Tensor + attention_mask: Optional[torch.LongTensor] + action_mask: Optional[torch.BoolTensor] + + +def split_experience_batch(experience: Experience) -> List[BufferItem]: + batch_size = experience.sequences.size(0) + batch_kwargs = [{} for _ in range(batch_size)] + keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask') + for key in keys: + value = getattr(experience, key) + if isinstance(value, torch.Tensor): + vals = torch.unbind(value) + else: + # None + vals = [value for _ in range(batch_size)] + assert batch_size == len(vals) + for i, v in enumerate(vals): + batch_kwargs[i][key] = v + items = [BufferItem(**kwargs) for kwargs in batch_kwargs] + return items + + +def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor: + assert side in ('left', 'right') + max_len = max(seq.size(0) for seq in sequences) + padded_sequences = [] + for seq in sequences: + pad_len = max_len - seq.size(0) + padding = (pad_len, 0) if side == 'left' else (0, pad_len) + padded_sequences.append(F.pad(seq, padding)) + return torch.stack(padded_sequences, dim=0) + + +def make_experience_batch(items: List[BufferItem]) -> Experience: + kwargs = {} + to_pad_keys = set(('action_log_probs', 'action_mask')) + keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask') + for key in keys: + vals = [getattr(item, key) for item in items] + if key in to_pad_keys: + batch_data = zero_pad_sequences(vals) + else: + batch_data = torch.stack(vals, dim=0) + kwargs[key] = batch_data + return Experience(**kwargs) diff --git a/applications/ChatGPT/chatgpt/trainer/__init__.py b/applications/ChatGPT/chatgpt/trainer/__init__.py new file mode 100644 index 000000000..c47c76347 --- /dev/null +++ b/applications/ChatGPT/chatgpt/trainer/__init__.py @@ -0,0 +1,5 @@ +from .base import Trainer +from .ppo import PPOTrainer +from .rm import RewardModelTrainer + +__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer'] diff --git a/applications/ChatGPT/chatgpt/trainer/base.py b/applications/ChatGPT/chatgpt/trainer/base.py new file mode 100644 index 000000000..42547af78 --- /dev/null +++ b/applications/ChatGPT/chatgpt/trainer/base.py @@ -0,0 +1,162 @@ +import random +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from chatgpt.experience_maker import Experience, ExperienceMaker +from chatgpt.replay_buffer import ReplayBuffer +from torch import Tensor +from torch.utils.data import DistributedSampler +from tqdm import tqdm + +from .callbacks import Callback +from .strategies import Strategy +from .utils import is_rank_0 + + +class Trainer(ABC): + """ + Base class for rlhf trainers. + + Args: + strategy (Strategy):the strategy to use for training + experience_maker (ExperienceMaker): the experience maker to use for produce experience to fullfill replay buffer + replay_buffer (ReplayBuffer): the replay buffer to use for training + experience_batch_size (int, defaults to 8): the batch size to use for experience generation + max_epochs (int, defaults to 1): the number of epochs of training process + tokenizer (Callable, optional): the tokenizer to use for tokenizing the input + sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer + data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader + callbacks (List[Callback], defaults to []): the callbacks to call during training process + generate_kwargs (dict, optional): the kwargs to use while model generating + """ + + def __init__(self, + strategy: Strategy, + experience_maker: ExperienceMaker, + replay_buffer: ReplayBuffer, + experience_batch_size: int = 8, + max_epochs: int = 1, + tokenizer: Optional[Callable[[Any], dict]] = None, + sample_replay_buffer: bool = False, + dataloader_pin_memory: bool = True, + callbacks: List[Callback] = [], + **generate_kwargs) -> None: + super().__init__() + self.strategy = strategy + self.experience_maker = experience_maker + self.replay_buffer = replay_buffer + self.experience_batch_size = experience_batch_size + self.max_epochs = max_epochs + self.tokenizer = tokenizer + self.generate_kwargs = generate_kwargs + self.sample_replay_buffer = sample_replay_buffer + self.dataloader_pin_memory = dataloader_pin_memory + self.callbacks = callbacks + + @abstractmethod + def training_step(self, experience: Experience) -> Dict[str, Any]: + pass + + def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience: + if isinstance(inputs, Tensor): + return self.experience_maker.make_experience(inputs, **self.generate_kwargs) + elif isinstance(inputs, dict): + return self.experience_maker.make_experience(**inputs, **self.generate_kwargs) + else: + raise ValueError(f'Unsupported input type "{type(inputs)}"') + + def _sample_prompts(self, prompts) -> list: + indices = list(range(len(prompts))) + sampled_indices = random.sample(indices, self.experience_batch_size) + return [prompts[i] for i in sampled_indices] + + def _learn(self): + # replay buffer may be empty at first, we should rebuild at each training + if not self.sample_replay_buffer: + dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory) + device = torch.cuda.current_device() + if self.sample_replay_buffer: + pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0()) + for _ in pbar: + experience = self.replay_buffer.sample() + metrics = self.training_step(experience) + pbar.set_postfix(metrics) + else: + for epoch in range(self.max_epochs): + self._on_learn_epoch_start(epoch) + if isinstance(dataloader.sampler, DistributedSampler): + dataloader.sampler.set_epoch(epoch) + pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0()) + for experience in pbar: + self._on_learn_batch_start() + experience.to_device(device) + metrics = self.training_step(experience) + self._on_learn_batch_end(metrics, experience) + pbar.set_postfix(metrics) + self._on_learn_epoch_end(epoch) + + def fit(self, prompts, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None: + time = 0 + self._on_fit_start() + for episode in range(num_episodes): + self._on_episode_start(episode) + for timestep in tqdm(range(max_timesteps), + desc=f'Episode [{episode+1}/{num_episodes}]', + disable=not is_rank_0()): + time += 1 + rand_prompts = self._sample_prompts(prompts) + if self.tokenizer is not None: + inputs = self.tokenizer(rand_prompts) + else: + inputs = rand_prompts + self._on_make_experience_start() + experience = self._make_experience(inputs) + self._on_make_experience_end(experience) + self.replay_buffer.append(experience) + if time % update_timesteps == 0: + self._learn() + self.replay_buffer.clear() + self._on_episode_end(episode) + self._on_fit_end() + + # TODO(ver217): maybe simplify these code using context + def _on_fit_start(self) -> None: + for callback in self.callbacks: + callback.on_fit_start() + + def _on_fit_end(self) -> None: + for callback in self.callbacks: + callback.on_fit_end() + + def _on_episode_start(self, episode: int) -> None: + for callback in self.callbacks: + callback.on_episode_start(episode) + + def _on_episode_end(self, episode: int) -> None: + for callback in self.callbacks: + callback.on_episode_end(episode) + + def _on_make_experience_start(self) -> None: + for callback in self.callbacks: + callback.on_make_experience_start() + + def _on_make_experience_end(self, experience: Experience) -> None: + for callback in self.callbacks: + callback.on_make_experience_end(experience) + + def _on_learn_epoch_start(self, epoch: int) -> None: + for callback in self.callbacks: + callback.on_learn_epoch_start(epoch) + + def _on_learn_epoch_end(self, epoch: int) -> None: + for callback in self.callbacks: + callback.on_learn_epoch_end(epoch) + + def _on_learn_batch_start(self) -> None: + for callback in self.callbacks: + callback.on_learn_batch_start() + + def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: + for callback in self.callbacks: + callback.on_learn_batch_end(metrics, experience) diff --git a/applications/ChatGPT/chatgpt/trainer/callbacks/__init__.py b/applications/ChatGPT/chatgpt/trainer/callbacks/__init__.py new file mode 100644 index 000000000..79ea9ffcd --- /dev/null +++ b/applications/ChatGPT/chatgpt/trainer/callbacks/__init__.py @@ -0,0 +1,4 @@ +from .base import Callback +from .performance_evaluator import PerformanceEvaluator + +__all__ = ['Callback', 'PerformanceEvaluator'] diff --git a/applications/ChatGPT/chatgpt/trainer/callbacks/base.py b/applications/ChatGPT/chatgpt/trainer/callbacks/base.py new file mode 100644 index 000000000..0b01345f7 --- /dev/null +++ b/applications/ChatGPT/chatgpt/trainer/callbacks/base.py @@ -0,0 +1,39 @@ +from abc import ABC + +from chatgpt.experience_maker import Experience + + +class Callback(ABC): + """ + Base callback class. It defines the interface for callbacks. + """ + + def on_fit_start(self) -> None: + pass + + def on_fit_end(self) -> None: + pass + + def on_episode_start(self, episode: int) -> None: + pass + + def on_episode_end(self, episode: int) -> None: + pass + + def on_make_experience_start(self) -> None: + pass + + def on_make_experience_end(self, experience: Experience) -> None: + pass + + def on_learn_epoch_start(self, epoch: int) -> None: + pass + + def on_learn_epoch_end(self, epoch: int) -> None: + pass + + def on_learn_batch_start(self) -> None: + pass + + def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: + pass diff --git a/applications/ChatGPT/chatgpt/trainer/callbacks/performance_evaluator.py b/applications/ChatGPT/chatgpt/trainer/callbacks/performance_evaluator.py new file mode 100644 index 000000000..faa38af1b --- /dev/null +++ b/applications/ChatGPT/chatgpt/trainer/callbacks/performance_evaluator.py @@ -0,0 +1,133 @@ +from time import time +from typing import Optional + +import torch +import torch.distributed as dist +from chatgpt.experience_maker import Experience + +from .base import Callback + + +def get_world_size() -> int: + if dist.is_initialized(): + return dist.get_world_size() + return 1 + + +def print_rank_0(*args, **kwargs) -> None: + if not dist.is_initialized() or dist.get_rank() == 0: + print(*args, **kwargs) + + +@torch.no_grad() +def all_reduce_mean(x: float, world_size: int) -> float: + if world_size == 1: + return x + tensor = torch.tensor([x], device=torch.cuda.current_device()) + dist.all_reduce(tensor) + tensor = tensor / world_size + return tensor.item() + + +class PerformanceEvaluator(Callback): + """ + Callback for valuate the performance of the model. + Args: + actor_num_params: The number of parameters of the actor model. + critic_num_params: The number of parameters of the critic model. + initial_model_num_params: The number of parameters of the initial model. + reward_model_num_params: The number of parameters of the reward model. + enable_grad_checkpoint: Whether to enable gradient checkpointing. + ignore_episodes: The number of episodes to ignore when calculating the performance. + """ + + def __init__(self, + actor_num_params: int, + critic_num_params: int, + initial_model_num_params: int, + reward_model_num_params: int, + enable_grad_checkpoint: bool = False, + ignore_episodes: int = 0) -> None: + super().__init__() + self.world_size = get_world_size() + self.actor_num_params = actor_num_params + self.critic_num_params = critic_num_params + self.initial_model_num_params = initial_model_num_params + self.reward_model_num_params = reward_model_num_params + self.enable_grad_checkpoint = enable_grad_checkpoint + self.ignore_episodes = ignore_episodes + self.disable: bool = False + + self.make_experience_duration: float = 0. + self.make_experience_start_time: Optional[float] = None + self.make_experience_num_samples: int = 0 + self.make_experience_flop: int = 0 + self.learn_duration: float = 0. + self.learn_start_time: Optional[float] = None + self.learn_num_samples: int = 0 + self.learn_flop: int = 0 + + def on_episode_start(self, episode: int) -> None: + self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes + + def on_make_experience_start(self) -> None: + if self.disable: + return + self.make_experience_start_time = time() + + def on_make_experience_end(self, experience: Experience) -> None: + if self.disable: + return + self.make_experience_duration += time() - self.make_experience_start_time + + batch_size, seq_len = experience.sequences.shape + + self.make_experience_num_samples += batch_size + + # actor generate + num_actions = experience.action_mask.size(1) + input_len = seq_len - num_actions + total_seq_len = (input_len + seq_len - 1) * num_actions / 2 + self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2 + # actor forward + self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2 + # critic forward + self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2 + # initial model forward + self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2 + # reward model forward + self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2 + + def on_learn_batch_start(self) -> None: + if self.disable: + return + self.learn_start_time = time() + + def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: + if self.disable: + return + self.learn_duration += time() - self.learn_start_time + + batch_size, seq_len = experience.sequences.shape + + self.learn_num_samples += batch_size + + # actor forward-backward, 3 means forward(1) + backward(2) + self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint)) + # critic foward-backward + self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint)) + + def on_fit_end(self) -> None: + avg_make_experience_duration = all_reduce_mean(self.make_experience_duration, self.world_size) + avg_learn_duration = all_reduce_mean(self.learn_duration, self.world_size) + + avg_make_experience_throughput = self.make_experience_num_samples / (avg_make_experience_duration + 1e-12) + avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12) + + avg_learn_throughput = self.learn_num_samples / (avg_learn_duration + 1e-12) + avg_learn_tflops = self.learn_flop / 1e12 / (avg_learn_duration + 1e-12) + + print_rank_0( + f'Making experience throughput: {avg_make_experience_throughput:.3f} samples/sec, TFLOPS: {avg_make_experience_tflops:.3f}' + ) + print_rank_0(f'Learning throughput: {avg_learn_throughput:.3f} samples/sec, TFLOPS: {avg_learn_tflops:.3f}') diff --git a/applications/ChatGPT/chatgpt/trainer/ppo.py b/applications/ChatGPT/chatgpt/trainer/ppo.py new file mode 100644 index 000000000..85beb223e --- /dev/null +++ b/applications/ChatGPT/chatgpt/trainer/ppo.py @@ -0,0 +1,104 @@ +from typing import Any, Callable, Dict, List, Optional + +import torch.nn as nn +from chatgpt.experience_maker import Experience, NaiveExperienceMaker +from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss +from chatgpt.replay_buffer import NaiveReplayBuffer +from torch.optim import Optimizer + +from .base import Trainer +from .callbacks import Callback +from .strategies import Strategy + + +class PPOTrainer(Trainer): + """ + Trainer for PPO algorithm. + + Args: + strategy (Strategy): the strategy to use for training + actor (Actor): the actor model in ppo algorithm + critic (Critic): the critic model in ppo algorithm + reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences + initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of actor + actor_optim (Optimizer): the optimizer to use for actor model + critic_optim (Optimizer): the optimizer to use for critic model + kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss + train_batch_size (int, defaults to 8): the batch size to use for training + buffer_limit (int, defaults to 0): the max_size limitaiton of replay buffer + buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu + eps_clip (float, defaults to 0.2): the clip coefficient of policy loss + value_clip (float, defaults to 0.4): the clip coefficient of value loss + experience_batch_size (int, defaults to 8): the batch size to use for experience generation + max_epochs (int, defaults to 1): the number of epochs of training process + tokenier (Callable, optional): the tokenizer to use for tokenizing the input + sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer + dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader + callbacks (List[Callback], defaults to []): the callbacks to call during training process + generate_kwargs (dict, optional): the kwargs to use while model generating + """ + + def __init__(self, + strategy: Strategy, + actor: Actor, + critic: Critic, + reward_model: nn.Module, + initial_model: Actor, + actor_optim: Optimizer, + critic_optim: Optimizer, + kl_coef: float = 0.1, + train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + eps_clip: float = 0.2, + value_clip: float = 0.4, + experience_batch_size: int = 8, + max_epochs: int = 1, + tokenizer: Optional[Callable[[Any], dict]] = None, + sample_replay_buffer: bool = False, + dataloader_pin_memory: bool = True, + callbacks: List[Callback] = [], + **generate_kwargs) -> None: + actor = Actor(strategy.setup_model(actor.model)) + critic = strategy.setup_model(critic) + reward_model = strategy.setup_model(reward_model) + initial_model = Actor(strategy.setup_model(initial_model.model)) + experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef) + replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) + super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer, + sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs) + self.actor = actor + self.critic = critic + + self.actor_loss_fn = PolicyLoss(eps_clip) + self.critic_loss_fn = ValueLoss(value_clip) + + self.actor_optim = strategy.setup_optimizer(actor_optim, self.actor.model) + self.critic_optim = strategy.setup_optimizer(critic_optim, self.critic) + + def training_step(self, experience: Experience) -> Dict[str, float]: + self.actor.train() + self.critic.train() + + num_actions = experience.action_mask.size(1) + action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask) + actor_loss = self.actor_loss_fn(action_log_probs, + experience.action_log_probs, + experience.advantages, + action_mask=experience.action_mask) + self.strategy.backward(actor_loss, self.actor, self.actor_optim) + self.strategy.optimizer_step(self.actor_optim) + self.actor_optim.zero_grad() + + values = self.critic(experience.sequences, + action_mask=experience.action_mask, + attention_mask=experience.attention_mask) + critic_loss = self.critic_loss_fn(values, + experience.values, + experience.reward, + action_mask=experience.action_mask) + self.strategy.backward(critic_loss, self.critic, self.critic_optim) + self.strategy.optimizer_step(self.critic_optim) + self.critic_optim.zero_grad() + + return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()} diff --git a/applications/ChatGPT/chatgpt/trainer/rm.py b/applications/ChatGPT/chatgpt/trainer/rm.py new file mode 100644 index 000000000..c24289502 --- /dev/null +++ b/applications/ChatGPT/chatgpt/trainer/rm.py @@ -0,0 +1,77 @@ +from abc import ABC + +import loralib as lora +from chatgpt.dataset import RewardDataset +from chatgpt.nn import PairWiseLoss +from torch.optim import Adam +from torch.utils.data import DataLoader +from tqdm import tqdm + + +class RewardModelTrainer(ABC): + """ + Trainer to use while training reward model. + + Args: + model (torch.nn.Module): the model to train + train_dataset (RewardDataset): the dataset to use for training + eval_dataset (RewardDataset): the dataset to use for evaluation + batch_size (int, defaults to 1): the batch size while training + num_epochs (int, defaults to 2): the number of epochs to train + optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer + """ + + def __init__(self, + model, + train_dataset: RewardDataset, + eval_dataset: RewardDataset, + batch_size: int = 1, + num_epochs: int = 2, + optim_kwargs: dict = {'lr': 1e-4}) -> None: + super().__init__() + self.model = model + self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size) + self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size) + self.loss_fn = PairWiseLoss() + self.optimizer = Adam(self.model.parameters(), **optim_kwargs) + self.epochs = num_epochs + + def fit(self, use_lora): + epoch_bar = tqdm(range(self.epochs), desc='Train epoch') + for epoch in range(self.epochs): + step_bar = tqdm(range(self.train_dataloader.__len__()), desc='Train step of epoch %d' % epoch) + # train + if use_lora > 0: + print("Using Lora") + lora.mark_only_lora_as_trainable(self.model) + else: + self.model.train() + for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: + chosen_ids = chosen_ids.squeeze(1).cuda() + c_mask = c_mask.squeeze(1).cuda() + reject_ids = reject_ids.squeeze(1).cuda() + r_mask = r_mask.squeeze(1).cuda() + chosen_reward = self.model(chosen_ids, attention_mask=c_mask) + reject_reward = self.model(reject_ids, attention_mask=r_mask) + loss = self.loss_fn(chosen_reward, reject_reward) + loss.backward() + self.optimizer.step() + self.optimizer.zero_grad() + step_bar.update() + step_bar.set_postfix({'loss': loss.item()}) + + # eval + self.model.eval() + for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader: + dist = 0 + chosen_ids = chosen_ids.squeeze(1).cuda() + c_mask = c_mask.squeeze(1).cuda() + reject_ids = reject_ids.squeeze(1).cuda() + r_mask = r_mask.squeeze(1).cuda() + chosen_reward = self.model(chosen_ids, attention_mask=c_mask) + reject_reward = self.model(reject_ids, attention_mask=r_mask) + dist += (chosen_reward - reject_reward) + dist_mean = dist / self.eval_dataloader.__len__() + epoch_bar.update() + step_bar.set_postfix({'loss': loss.item(), 'dist_mean': dist_mean.item()}) + step_bar.close() diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/__init__.py b/applications/ChatGPT/chatgpt/trainer/strategies/__init__.py new file mode 100644 index 000000000..f258c9b8a --- /dev/null +++ b/applications/ChatGPT/chatgpt/trainer/strategies/__init__.py @@ -0,0 +1,6 @@ +from .base import Strategy +from .colossalai import ColossalAIStrategy +from .ddp import DDPStrategy +from .naive import NaiveStrategy + +__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy'] diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/base.py b/applications/ChatGPT/chatgpt/trainer/strategies/base.py new file mode 100644 index 000000000..3a2923b8c --- /dev/null +++ b/applications/ChatGPT/chatgpt/trainer/strategies/base.py @@ -0,0 +1,45 @@ +from abc import ABC, abstractmethod +from contextlib import nullcontext + +import torch +import torch.nn as nn +import torch.optim as optim +from chatgpt.replay_buffer import ReplayBuffer +from torch.utils.data import DataLoader + + +class Strategy(ABC): + """ + Base class for training strategies. + """ + + def __init__(self) -> None: + super().__init__() + self.setup_distributed() + + @abstractmethod + def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None: + pass + + @abstractmethod + def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None: + pass + + @abstractmethod + def setup_distributed(self) -> None: + pass + + @abstractmethod + def setup_model(self, model: nn.Module) -> nn.Module: + pass + + @abstractmethod + def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer: + pass + + @abstractmethod + def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: + pass + + def model_init_context(self): + return nullcontext() diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py new file mode 100644 index 000000000..665bfa913 --- /dev/null +++ b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py @@ -0,0 +1,125 @@ +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim + +import colossalai +from colossalai.nn.optimizer import CPUAdam, HybridAdam +from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper +from colossalai.tensor import ProcessGroup, ShardSpec +from colossalai.utils import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext + +from .ddp import DDPStrategy + + +class ColossalAIStrategy(DDPStrategy): + """ + The strategy for training with ColossalAI. + + Args: + stage(int): The stage to use in ZeRO. Choose in (1, 2, 3) + seed(int): The seed for the random number generator. + shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3. + placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda') + If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU, + If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest. + pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3. + force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3. + search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3. + hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3. + min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3. + gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3. + reduce_bugket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2. + overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2. + initial_scale(float): The initial scale for the optimizer. + growth_factor(float): The growth factor for the optimizer. + backoff_factor(float): The backoff factor for the optimizer. + growth_interval(int): The growth interval for the optimizer. + hysteresis(int): The hysteresis for the optimizer. + min_scale(float): The minimum scale for the optimizer. + max_scale(float): The maximum scale for the optimizer. + max_norm(float): The maximum norm for the optimizer. + norm_type(float): The norm type for the optimizer. + + """ + + def __init__( + self, + stage: int = 3, + seed: int = 42, + shard_init: bool = True, # only for stage 3 + placement_policy: str = 'cuda', + pin_memory: bool = True, # only for stage 3 + force_outputs_fp32: bool = False, # only for stage 3 + search_range_mb: int = 32, # only for stage 3 + hidden_dim: Optional[int] = None, # only for stage 3 + min_chunk_size_mb: float = 32, # only for stage 3 + gpu_margin_mem_ratio: float = 0.0, # only for stage 3 + reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2 + overlap_communication: bool = True, # only for stage 1&2 + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0) -> None: + super().__init__(seed) + assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"' + self.stage = stage + self.shard_init = shard_init + self.gemini_config = dict(device=get_current_device(), + placement_policy=placement_policy, + pin_memory=pin_memory, + force_outputs_fp32=force_outputs_fp32, + strict_ddp_mode=shard_init, + search_range_mb=search_range_mb, + hidden_dim=hidden_dim, + min_chunk_size_mb=min_chunk_size_mb) + if stage == 3: + self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio) + else: + self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size, + overlap_communication=overlap_communication, + cpu_offload=(placement_policy == 'cpu')) + self.optim_kwargs = dict(initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type) + + def setup_distributed(self) -> None: + colossalai.launch_from_torch({}, seed=self.seed) + + def model_init_context(self): + if self.stage == 3: + world_size = dist.get_world_size() + shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None + default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None + return ColoInitContext(device=get_current_device(), + dtype=torch.half, + default_pg=shard_pg, + default_dist_spec=default_dist_spec) + return super().model_init_context() + + def setup_model(self, model: nn.Module) -> nn.Module: + return zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config) + + def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer: + assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}' + return zero_optim_wrapper(model, optimizer, optim_config=self.zero_optim_config, **self.optim_kwargs) + + def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None: + optimizer.backward(loss) + + def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None: + optimizer.step() diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py b/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py new file mode 100644 index 000000000..b636515b4 --- /dev/null +++ b/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py @@ -0,0 +1,59 @@ +import os +import random + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from chatgpt.replay_buffer import ReplayBuffer +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader, DistributedSampler + +from .naive import NaiveStrategy + + +class DDPStrategy(NaiveStrategy): + """ + Strategy for distributed training using torch.distributed. + """ + + def __init__(self, seed: int = 42) -> None: + self.seed = seed + super().__init__() + + def setup_distributed(self) -> None: + try: + rank = int(os.environ['RANK']) + local_rank = int(os.environ['LOCAL_RANK']) + world_size = int(os.environ['WORLD_SIZE']) + host = os.environ['MASTER_ADDR'] + port = int(os.environ['MASTER_PORT']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" + ) + dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank) + self.set_seed(self.seed) + torch.cuda.set_device(local_rank) + + def set_seed(self, seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + def setup_model(self, model: nn.Module) -> nn.Module: + device = torch.cuda.current_device() + return DDP(model, device_ids=[device]) + + def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: + sampler = DistributedSampler(replay_buffer, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=True, + seed=self.seed, + drop_last=True) + return DataLoader(replay_buffer, + batch_size=replay_buffer.sample_batch_size, + sampler=sampler, + pin_memory=pin_memory, + collate_fn=replay_buffer.collate_fn) diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/naive.py b/applications/ChatGPT/chatgpt/trainer/strategies/naive.py new file mode 100644 index 000000000..1bb472ae6 --- /dev/null +++ b/applications/ChatGPT/chatgpt/trainer/strategies/naive.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from chatgpt.replay_buffer import ReplayBuffer +from torch.utils.data import DataLoader + +from .base import Strategy + + +class NaiveStrategy(Strategy): + """ + Strategy for single GPU. No parallelism is used. + """ + + def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None: + loss.backward() + + def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None: + optimizer.step() + + def setup_distributed(self) -> None: + pass + + def setup_model(self, model: nn.Module) -> nn.Module: + return model + + def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer: + return optimizer + + def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: + return DataLoader(replay_buffer, + batch_size=replay_buffer.sample_batch_size, + shuffle=True, + drop_last=True, + pin_memory=pin_memory, + collate_fn=replay_buffer.collate_fn) diff --git a/applications/ChatGPT/chatgpt/trainer/utils.py b/applications/ChatGPT/chatgpt/trainer/utils.py new file mode 100644 index 000000000..6c9f7f085 --- /dev/null +++ b/applications/ChatGPT/chatgpt/trainer/utils.py @@ -0,0 +1,5 @@ +import torch.distributed as dist + + +def is_rank_0() -> bool: + return not dist.is_initialized() or dist.get_rank() == 0 diff --git a/applications/ChatGPT/examples/README.md b/applications/ChatGPT/examples/README.md new file mode 100644 index 000000000..5f9d8698d --- /dev/null +++ b/applications/ChatGPT/examples/README.md @@ -0,0 +1,105 @@ +# Examples + +## Install requirements + +```shell +pip install -r requirements.txt +``` + +## Train with dummy prompt data + +This script supports 3 strategies: + +- naive +- ddp +- colossalai + +It uses random generated prompt data. + +Naive strategy only support single GPU training: + +```shell +python train_dummy.py --strategy naive +# display cli help +python train_dummy.py -h +``` + +DDP strategy and ColossalAI strategy support multi GPUs training: + +```shell +# run DDP on 2 GPUs +torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy ddp +# run ColossalAI on 2 GPUs +torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai +``` + +## Train with real prompt data + +We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-chatgpt-prompts) as example dataset. It is a small dataset with hundreds of prompts. + +You should download `prompts.csv` first. + +This script also supports 3 strategies. + +```shell +# display cli help +python train_dummy.py -h +# run naive on 1 GPU +python train_prompts.py prompts.csv --strategy naive +# run DDP on 2 GPUs +torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy ddp +# run ColossalAI on 2 GPUs +torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai +``` + +## Train the reward model +We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as dataset to train our reward model. It is a dataset of chosen & rejected response of the same prompt. + +You can download the dataset from huggingface automatically. + +Use these code to train your reward model. + +```shell +# Naive reward model training +python train_reward_model.py --pretrain +# if to use LoRA +python train_reward_model.py --pretrain --lora_rank 16 +``` + +## Support Model + +### GPT +- [ ] GPT2-S (s) +- [ ] GPT2-M (m) +- [ ] GPT2-L (l) +- [ ] GPT2-XL (xl) +- [ ] GPT2-4B (4b) +- [ ] GPT2-6B (6b) +- [ ] GPT2-8B (8b) +- [ ] GPT2-10B (10b) +- [ ] GPT2-12B (12b) +- [ ] GPT2-15B (15b) +- [ ] GPT2-18B (18b) +- [ ] GPT2-20B (20b) +- [ ] GPT2-24B (24b) +- [ ] GPT2-28B (28b) +- [ ] GPT2-32B (32b) +- [ ] GPT2-36B (36b) +- [ ] GPT2-40B (40b) +- [ ] GPT3 (175b) + +### BLOOM +- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m) +- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1) +- [ ] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b) +- [ ] [BLOOM-7b](https://huggingface.co/bigscience/bloomz-7b1) +- [ ] BLOOM-175b + +### OPT +- [x] [OPT-125M](https://huggingface.co/facebook/opt-125m) +- [x] [OPT-350M](https://huggingface.co/facebook/opt-350m) +- [ ] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b) +- [ ] [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b) +- [ ] [OPT-6.7B](https://huggingface.co/facebook/opt-6.7b) +- [ ] [OPT-13B](https://huggingface.co/facebook/opt-13b) +- [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b) diff --git a/applications/ChatGPT/examples/requirements.txt b/applications/ChatGPT/examples/requirements.txt new file mode 100644 index 000000000..6c5dac292 --- /dev/null +++ b/applications/ChatGPT/examples/requirements.txt @@ -0,0 +1 @@ +pandas>=1.4.1 diff --git a/applications/ChatGPT/examples/test_ci.sh b/applications/ChatGPT/examples/test_ci.sh new file mode 100755 index 000000000..c4a5ead1d --- /dev/null +++ b/applications/ChatGPT/examples/test_ci.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +set -xue + +if [ -z "$PROMPT_PATH" ]; then + echo "Please set \$PROMPT_PATH to the path to prompts csv." + exit 1 +fi + +BASE=$(realpath $(dirname $0)) + +export OMP_NUM_THREADS=8 + +# install requirements +pip install -r ${BASE}/requirements.txt + +# train dummy +python ${BASE}/train_dummy.py --strategy naive --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2 +for strategy in ddp colossalai_gemini colossalai_zero2; do + torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py --strategy ${strategy} --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2 +done + +# train prompts +python ${BASE}/train_prompts.py $PROMPT_PATH --strategy naive --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 +for strategy in ddp colossalai_gemini colossalai_zero2; do + torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH --strategy ${strategy} --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2 +done diff --git a/applications/ChatGPT/examples/train_dummy.py b/applications/ChatGPT/examples/train_dummy.py new file mode 100644 index 000000000..313be2c3b --- /dev/null +++ b/applications/ChatGPT/examples/train_dummy.py @@ -0,0 +1,121 @@ +import argparse +from copy import deepcopy + +import torch +from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel +from chatgpt.nn.generation_utils import ( + bloom_prepare_inputs_fn, + gpt_prepare_inputs_fn, + opt_prepare_inputs_fn, + update_model_kwargs_fn, +) +from chatgpt.trainer import PPOTrainer +from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from torch.optim import Adam +from transformers import AutoTokenizer, BloomTokenizerFast +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.nn.optimizer import HybridAdam + + +def preprocess_batch(samples): + input_ids = torch.stack(samples) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + return {'input_ids': input_ids, 'attention_mask': attention_mask} + + +def main(args): + # configure strategy + if args.strategy == 'naive': + strategy = NaiveStrategy() + elif args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + # configure model + with strategy.model_init_context(): + if args.model == 'gpt2': + actor = GPTActor().cuda() + critic = GPTCritic().cuda() + elif args.model == 'bloom': + actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() + critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() + elif args.model == 'opt': + actor = OPTActor().cuda() + critic = OPTCritic().cuda() + else: + raise ValueError(f'Unsupported model "{args.model}"') + + initial_model = deepcopy(actor).cuda() + reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda() + + # configure optimizer + if args.strategy.startswith('colossalai'): + actor_optim = HybridAdam(actor.parameters(), lr=5e-6) + critic_optim = HybridAdam(critic.parameters(), lr=5e-6) + else: + actor_optim = Adam(actor.parameters(), lr=5e-6) + critic_optim = Adam(critic.parameters(), lr=5e-6) + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + prepare_inputs_fn = gpt_prepare_inputs_fn + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + prepare_inputs_fn = bloom_prepare_inputs_fn + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + prepare_inputs_fn = opt_prepare_inputs_fn + else: + raise ValueError(f'Unsupported model "{args.model}"') + + # configure trainer + trainer = PPOTrainer(strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + max_epochs=args.max_epochs, + train_batch_size=args.train_batch_size, + tokenizer=preprocess_batch, + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + prepare_inputs_fn=prepare_inputs_fn, + update_model_kwargs_fn=update_model_kwargs_fn) + + random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device()) + trainer.fit(random_prompts, + num_episodes=args.num_episodes, + max_timesteps=args.max_timesteps, + update_timesteps=args.update_timesteps) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--num_episodes', type=int, default=50) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + args = parser.parse_args() + main(args) diff --git a/applications/ChatGPT/examples/train_dummy.sh b/applications/ChatGPT/examples/train_dummy.sh new file mode 100755 index 000000000..559d338ee --- /dev/null +++ b/applications/ChatGPT/examples/train_dummy.sh @@ -0,0 +1,18 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 1 + +python train_dummy.py --model bloom --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16 diff --git a/applications/ChatGPT/examples/train_prompts.py b/applications/ChatGPT/examples/train_prompts.py new file mode 100644 index 000000000..994b10fe0 --- /dev/null +++ b/applications/ChatGPT/examples/train_prompts.py @@ -0,0 +1,113 @@ +import argparse +from copy import deepcopy + +import pandas as pd +from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel +from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn +from chatgpt.trainer import PPOTrainer +from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from torch.optim import Adam +from transformers import AutoTokenizer, BloomTokenizerFast +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.nn.optimizer import HybridAdam + + +def main(args): + # configure strategy + if args.strategy == 'naive': + strategy = NaiveStrategy() + elif args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + # configure model + with strategy.model_init_context(): + if args.model == 'gpt2': + actor = GPTActor().cuda() + critic = GPTCritic().cuda() + elif args.model == 'bloom': + actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() + critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() + elif args.model == 'opt': + actor = OPTActor(lora_rank=args.lora_rank).cuda() + critic = OPTCritic(lora_rank=args.lora_rank).cuda() + else: + raise ValueError(f'Unsupported model "{args.model}"') + + initial_model = deepcopy(actor) + reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda() + + # configure optimizer + if args.strategy.startswith('colossalai'): + actor_optim = HybridAdam(actor.parameters(), lr=5e-6) + critic_optim = HybridAdam(critic.parameters(), lr=5e-6) + else: + actor_optim = Adam(actor.parameters(), lr=5e-6) + critic_optim = Adam(critic.parameters(), lr=5e-6) + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + else: + raise ValueError(f'Unsupported model "{args.model}"') + + dataset = pd.read_csv(args.prompt_path)['prompt'] + + def tokenize_fn(texts): + batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True) + return {k: v.cuda() for k, v in batch.items()} + + # configure trainer + trainer = PPOTrainer(strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + max_epochs=args.max_epochs, + train_batch_size=args.train_batch_size, + tokenizer=tokenize_fn, + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + prepare_inputs_fn=gpt_prepare_inputs_fn, + update_model_kwargs_fn=update_model_kwargs_fn) + + trainer.fit(dataset, + num_episodes=args.num_episodes, + max_timesteps=args.max_timesteps, + update_timesteps=args.update_timesteps) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('prompt_path') + parser.add_argument('--strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + args = parser.parse_args() + main(args) diff --git a/applications/ChatGPT/examples/train_prompts.sh b/applications/ChatGPT/examples/train_prompts.sh new file mode 100755 index 000000000..0b82d3f1c --- /dev/null +++ b/applications/ChatGPT/examples/train_prompts.sh @@ -0,0 +1,18 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 1 + +python train_prompts.py prompts.csv --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16 diff --git a/applications/ChatGPT/examples/train_reward_model.py b/applications/ChatGPT/examples/train_reward_model.py new file mode 100644 index 000000000..fd78a2ac6 --- /dev/null +++ b/applications/ChatGPT/examples/train_reward_model.py @@ -0,0 +1,53 @@ +import argparse + +import loralib as lora +import torch +from chatgpt.dataset import RewardDataset +from chatgpt.nn import BLOOMRM +from chatgpt.trainer import RewardModelTrainer +from datasets import load_dataset +from transformers import BloomTokenizerFast + + +def train(args): + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + model = BLOOMRM(pretrained=args.pretrain) + + model.cuda() + + max_len = 1024 + + # prepare for data and dataset + data = load_dataset(args.dataset) + train_data = data["train"] + eval_data = data['test'] + train_dataset = RewardDataset(train_data, tokenizer, max_len) + eval_dataset = RewardDataset(eval_data, tokenizer, max_len) + + # batch_size here is expected to be C(k,2), k means # response of each prompt + # be limited with the format of dataset 'Dahoas/rm-static', we'd better use batch_size as 1 + trainer = RewardModelTrainer(model=model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + batch_size=args.batch_size, + num_epochs=args.max_epochs) + + trainer.fit(use_lora=args.lora_rank) + + if args.lora_rank > 0: + torch.save({'model_state_dict': lora.lora_state_dict(trainer.model)}, args.save_path) + else: + torch.save(trainer.model, args.save_path) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--dataset', type=str, default='Dahoas/rm-static') + parser.add_argument('--save_path', type=str, default='rm_ckpt.pth') + parser.add_argument('--max_epochs', type=int, default=2) + parser.add_argument('--batch_size', type=int, default=1) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + args = parser.parse_args() + train(args) diff --git a/applications/ChatGPT/examples/train_rm.sh b/applications/ChatGPT/examples/train_rm.sh new file mode 100755 index 000000000..bf46d7e43 --- /dev/null +++ b/applications/ChatGPT/examples/train_rm.sh @@ -0,0 +1,18 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 1 + +python train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16 diff --git a/applications/ChatGPT/pytest.ini b/applications/ChatGPT/pytest.ini new file mode 100644 index 000000000..01e5cd217 --- /dev/null +++ b/applications/ChatGPT/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +markers = + cpu: tests which can run on CPU + gpu: tests which requires a single GPU + dist: tests which are run in a multi-GPU or multi-machine environment + experiment: tests for experimental features diff --git a/applications/ChatGPT/requirements/requirements-test.txt b/applications/ChatGPT/requirements/requirements-test.txt new file mode 100644 index 000000000..e079f8a60 --- /dev/null +++ b/applications/ChatGPT/requirements/requirements-test.txt @@ -0,0 +1 @@ +pytest diff --git a/applications/ChatGPT/requirements/requirements.txt b/applications/ChatGPT/requirements/requirements.txt new file mode 100644 index 000000000..87f6a52cc --- /dev/null +++ b/applications/ChatGPT/requirements/requirements.txt @@ -0,0 +1,6 @@ +transformers>=4.20.1 +tqdm +datasets +loralib +colossalai>=0.2.4 +torch diff --git a/applications/ChatGPT/setup.py b/applications/ChatGPT/setup.py new file mode 100644 index 000000000..f9607190a --- /dev/null +++ b/applications/ChatGPT/setup.py @@ -0,0 +1,42 @@ +from setuptools import find_packages, setup + + +def fetch_requirements(path): + with open(path, 'r') as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme(): + with open('README.md', encoding='utf-8') as f: + return f.read() + + +def fetch_version(): + with open('version.txt', 'r') as f: + return f.read().strip() + + +setup( + name='chatgpt', + version=fetch_version(), + packages=find_packages(exclude=( + 'tests', + 'benchmarks', + 'requirements', + '*.egg-info', + )), + description='A RLFH implementation (ChatGPT) powered by ColossalAI', + long_description=fetch_readme(), + long_description_content_type='text/markdown', + license='Apache Software License 2.0', + url='https://github.com/hpcaitech/ChatGPT', + install_requires=fetch_requirements('requirements/requirements.txt'), + python_requires='>=3.6', + classifiers=[ + 'Programming Language :: Python :: 3', + 'License :: OSI Approved :: Apache Software License', + 'Environment :: GPU :: NVIDIA CUDA', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: System :: Distributed Computing', + ], +) diff --git a/applications/ChatGPT/tests/__init__.py b/applications/ChatGPT/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/ChatGPT/tests/test_data.py b/applications/ChatGPT/tests/test_data.py new file mode 100644 index 000000000..b0a9433c2 --- /dev/null +++ b/applications/ChatGPT/tests/test_data.py @@ -0,0 +1,117 @@ +import os +from copy import deepcopy +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from chatgpt.experience_maker import NaiveExperienceMaker +from chatgpt.nn import GPTActor, GPTCritic, RewardModel +from chatgpt.replay_buffer import NaiveReplayBuffer +from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy + +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port + + +def get_data(batch_size: int, seq_len: int = 10) -> dict: + input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda') + attention_mask = torch.ones_like(input_ids) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def gather_and_equal(tensor: torch.Tensor) -> bool: + world_size = dist.get_world_size() + outputs = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(outputs, tensor.contiguous()) + for t in outputs[1:]: + if not torch.equal(outputs[0], t): + return False + return True + + +def run_test_data(strategy): + EXPERINCE_BATCH_SIZE = 4 + SAMPLE_BATCH_SIZE = 2 + + if strategy == 'ddp': + strategy = DDPStrategy() + elif strategy == 'colossalai': + strategy = ColossalAIStrategy(placement_policy='cuda') + else: + raise ValueError(f'Unsupported strategy "{strategy}"') + + actor = GPTActor().cuda() + critic = GPTCritic().cuda() + + initial_model = deepcopy(actor) + reward_model = RewardModel(deepcopy(critic.model)).cuda() + + experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model) + replay_buffer = NaiveReplayBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False) + + # experience of all ranks should be the same + for _ in range(2): + data = get_data(EXPERINCE_BATCH_SIZE) + assert gather_and_equal(data['input_ids']) + assert gather_and_equal(data['attention_mask']) + experience = experience_maker.make_experience(**data, + do_sample=True, + max_length=16, + eos_token_id=50256, + pad_token_id=50256) + assert gather_and_equal(experience.sequences) + assert gather_and_equal(experience.action_log_probs) + assert gather_and_equal(experience.values) + assert gather_and_equal(experience.reward) + assert gather_and_equal(experience.advantages) + assert gather_and_equal(experience.action_mask) + assert gather_and_equal(experience.attention_mask) + replay_buffer.append(experience) + + # replay buffer's data should be the same + buffer_size = torch.tensor([len(replay_buffer)], device='cuda') + assert gather_and_equal(buffer_size) + for item in replay_buffer.items: + assert gather_and_equal(item.sequences) + assert gather_and_equal(item.action_log_probs) + assert gather_and_equal(item.values) + assert gather_and_equal(item.reward) + assert gather_and_equal(item.advantages) + assert gather_and_equal(item.action_mask) + assert gather_and_equal(item.attention_mask) + + # dataloader of each rank should have the same size and different batch + dataloader = strategy.setup_dataloader(replay_buffer) + dataloader_size = torch.tensor([len(dataloader)], device='cuda') + assert gather_and_equal(dataloader_size) + for experience in dataloader: + assert not gather_and_equal(experience.sequences) + assert not gather_and_equal(experience.action_log_probs) + assert not gather_and_equal(experience.values) + assert not gather_and_equal(experience.reward) + assert not gather_and_equal(experience.advantages) + # action mask and attention mask may be same + + +def run_dist(rank, world_size, port, strategy): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = str(port) + run_test_data(strategy) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('strategy', ['ddp', 'colossalai']) +@rerun_if_address_is_in_use() +def test_data(world_size, strategy): + run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_data(2, 'colossalai') diff --git a/applications/ChatGPT/version.txt b/applications/ChatGPT/version.txt new file mode 100644 index 000000000..6e8bf73aa --- /dev/null +++ b/applications/ChatGPT/version.txt @@ -0,0 +1 @@ +0.1.0 From 8408c852a69036de73984f981b81874223ae8835 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Tue, 14 Feb 2023 22:48:15 +0800 Subject: [PATCH 09/30] [app] fix ChatGPT requirements (#2704) --- applications/ChatGPT/{requirements => }/requirements-test.txt | 0 applications/ChatGPT/{requirements => }/requirements.txt | 0 applications/ChatGPT/setup.py | 3 +-- 3 files changed, 1 insertion(+), 2 deletions(-) rename applications/ChatGPT/{requirements => }/requirements-test.txt (100%) rename applications/ChatGPT/{requirements => }/requirements.txt (100%) diff --git a/applications/ChatGPT/requirements/requirements-test.txt b/applications/ChatGPT/requirements-test.txt similarity index 100% rename from applications/ChatGPT/requirements/requirements-test.txt rename to applications/ChatGPT/requirements-test.txt diff --git a/applications/ChatGPT/requirements/requirements.txt b/applications/ChatGPT/requirements.txt similarity index 100% rename from applications/ChatGPT/requirements/requirements.txt rename to applications/ChatGPT/requirements.txt diff --git a/applications/ChatGPT/setup.py b/applications/ChatGPT/setup.py index f9607190a..deec10e0c 100644 --- a/applications/ChatGPT/setup.py +++ b/applications/ChatGPT/setup.py @@ -22,7 +22,6 @@ setup( packages=find_packages(exclude=( 'tests', 'benchmarks', - 'requirements', '*.egg-info', )), description='A RLFH implementation (ChatGPT) powered by ColossalAI', @@ -30,7 +29,7 @@ setup( long_description_content_type='text/markdown', license='Apache Software License 2.0', url='https://github.com/hpcaitech/ChatGPT', - install_requires=fetch_requirements('requirements/requirements.txt'), + install_requires=fetch_requirements('requirements.txt'), python_requires='>=3.6', classifiers=[ 'Programming Language :: Python :: 3', From 6a8cd687e3bc7da424390a508734976287f4260f Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Tue, 14 Feb 2023 22:48:30 +0800 Subject: [PATCH 10/30] [doc] add ChatGPT (#2703) --- README-zh-Hans.md | 28 ++++++++++++++++++++++++++-- README.md | 29 +++++++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/README-zh-Hans.md b/README-zh-Hans.md index 4b0ba9c42..e16db47f9 100644 --- a/README-zh-Hans.md +++ b/README-zh-Hans.md @@ -3,7 +3,7 @@ [![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/colossal-ai_logo_vertical.png)](https://www.colossalai.org/) - Colossal-AI: 一个面向大模型时代的通用深度学习系统 + Colossal-AI: 让AI大模型更低成本、方便易用、高效扩展

论文 | 文档 | @@ -23,10 +23,10 @@ ## 新闻 +* [2023/02] [Open source solution replicates ChatGPT training process! Ready to go with only 1.6GB GPU memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt) * [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0) * [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) * [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding) -* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for) * [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the) @@ -64,6 +64,7 @@
  • Colossal-AI 成功案例 @@ -209,6 +210,29 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的

    (返回顶端)

    ## Colossal-AI 成功案例 +### ChatGPT +低成本复现[ChatGPT](https://openai.com/blog/chatgpt/)完整流程 [[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ChatGPT) [[博客]](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt) +

    + +

    + +- 最高可提升单机训练速度7.73倍,单卡推理速度1.42倍 + +

    + +

    + +- 单卡模型容量最多提升10.3倍 +- 最小demo训练流程最低仅需1.62GB显存 (任意消费级GPU) + +

    + +

    + +- 提升单卡的微调模型容量3.7倍 +- 同时保持高速运行 + +

    (back to top)

    ### AIGC 加速AIGC(AI内容生成)模型,如[Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) 和 [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion) diff --git a/README.md b/README.md index 703e3f3bf..e4ffca890 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/colossal-ai_logo_vertical.png)](https://www.colossalai.org/) - Colossal-AI: A Unified Deep Learning System for Big Model Era + Colossal-AI: Make big AI models cheaper, easier, and scalable

    Paper | Documentation | @@ -24,10 +24,10 @@ ## Latest News +* [2023/02] [Open source solution replicates ChatGPT training process! Ready to go with only 1.6GB GPU memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt) * [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0) * [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) * [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding) -* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for) * [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the) ## Table of Contents @@ -64,6 +64,7 @@
  • Colossal-AI for Real World Applications @@ -211,6 +212,30 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt

    (back to top)

    ## Colossal-AI in the Real World +### ChatGPT +A low-cost [ChatGPT](https://openai.com/blog/chatgpt/) equivalent implementation process. [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ChatGPT) [[blog]](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt) +

    + +

    + +- Up to 7.73 times faster for single server training and 1.42 times faster for single-GPU inference + +

    + +

    + +- Up to 10.3x growth in model capacity on one GPU +- A mini demo training process requires only 1.62GB of GPU memory (any consumer-grade GPU) + +

    + +

    + +- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU +- Keep in a sufficiently high running speed + +

    (back to top)

    + ### AIGC Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion). From 71deddc87f8d8d108fc5198708873b81aefa11b0 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Tue, 14 Feb 2023 22:56:15 +0800 Subject: [PATCH 11/30] [doc] resize figure (#2705) * [doc] resize figure * [doc] resize figure --- README-zh-Hans.md | 4 ++-- README.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README-zh-Hans.md b/README-zh-Hans.md index e16db47f9..18623d67a 100644 --- a/README-zh-Hans.md +++ b/README-zh-Hans.md @@ -219,14 +219,14 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 - 最高可提升单机训练速度7.73倍,单卡推理速度1.42倍

    - +

    - 单卡模型容量最多提升10.3倍 - 最小demo训练流程最低仅需1.62GB显存 (任意消费级GPU)

    - +

    - 提升单卡的微调模型容量3.7倍 diff --git a/README.md b/README.md index e4ffca890..20a5f2606 100644 --- a/README.md +++ b/README.md @@ -221,14 +221,14 @@ A low-cost [ChatGPT](https://openai.com/blog/chatgpt/) equivalent implementation - Up to 7.73 times faster for single server training and 1.42 times faster for single-GPU inference

    - +

    - Up to 10.3x growth in model capacity on one GPU - A mini demo training process requires only 1.62GB of GPU memory (any consumer-grade GPU)

    - +

    - Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU From 94f000515b3f5700934072c39890f45cf419eebc Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Tue, 14 Feb 2023 23:07:30 +0800 Subject: [PATCH 12/30] [doc] add Quick Preview (#2706) --- applications/ChatGPT/README.md | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/applications/ChatGPT/README.md b/applications/ChatGPT/README.md index dce59ad4b..43085f3ab 100644 --- a/applications/ChatGPT/README.md +++ b/applications/ChatGPT/README.md @@ -1,6 +1,6 @@ -# RLHF - ColossalAI +# RLHF - Colossal-AI -Implementation of RLHF (Reinforcement Learning with Human Feedback) powered by ColossalAI. It supports distributed training and offloading, which can fit extremly large models. +Implementation of RLHF (Reinforcement Learning with Human Feedback) powered by Colossal-AI. It supports distributed training and offloading, which can fit extremly large models. More details can be found in the [blog](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt).

    @@ -60,6 +60,27 @@ We also support training reward model with true-world data. See `examples/train_ - [ ] integrate with Ray - [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL) +## Quick Preview +

    + +

    + +- Up to 7.73 times faster for single server training and 1.42 times faster for single-GPU inference + +

    + +

    + +- Up to 10.3x growth in model capacity on one GPU +- A mini demo training process requires only 1.62GB of GPU memory (any consumer-grade GPU) + +

    + +

    + +- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU +- Keep in a sufficiently high running speed + ## Citations ```bibtex From d701ef81b1ef59b7e3d09106a9482b9d9e19b2dd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 15 Feb 2023 09:39:44 +0800 Subject: [PATCH 13/30] Automated submodule synchronization (#2707) Co-authored-by: github-actions --- examples/tutorial/fastfold/FastFold | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tutorial/fastfold/FastFold b/examples/tutorial/fastfold/FastFold index 0188361b6..f05e71298 160000 --- a/examples/tutorial/fastfold/FastFold +++ b/examples/tutorial/fastfold/FastFold @@ -1 +1 @@ -Subproject commit 0188361b6e2b46bca61d37af5674eacf7ca9947f +Subproject commit f05e712982aeba6a32a9b3d1ee4dee6492426cec From 7fa6be49d2ac1eae2eda60f150597f0d3998ddf7 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Wed, 15 Feb 2023 09:43:29 +0800 Subject: [PATCH 14/30] [autoparallel] test compatibility for gemini and auto parallel (#2700) --- .../passes/runtime_preparation_pass.py | 10 +- .../test_compatibility_with_ddp.py | 98 ++++++++++++++++ .../test_compatibility_with_gemini.py | 108 ++++++++++++++++++ 3 files changed, 212 insertions(+), 4 deletions(-) create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index 897602ce1..ecf3f1f18 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -377,8 +377,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o # TODO: build a ColoParamter class to manager the distributed parameters # we could use .data here, because all the operations just happen before the real training # loop, so we don't need to track these operations in the autograd graph. - param.data = shape_consistency_manager.apply_for_autoparallel_runtime( - param.data, param.sharding_spec, target_sharding_spec).detach().clone() + param = torch.nn.Parameter( + shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, + target_sharding_spec).detach().clone()) setattr(target_module, name, param) comm_actions = node.best_strategy.communication_actions @@ -432,8 +433,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o # TODO: build a ColoParamter class to manager the distributed parameters # we could use .data here, because all the operations just happen before the real training # loop, so we don't need to track these operations in the autograd graph. - target.data = shape_consistency_manager.apply_for_autoparallel_runtime( - target.data, target.sharding_spec, target_sharding_spec).detach().clone() + target = torch.nn.Parameter( + shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec, + target_sharding_spec).detach().clone()) assert hasattr(target_module, atoms[-1]) setattr(target_module, atoms[-1], target) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py new file mode 100644 index 000000000..365981f10 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py @@ -0,0 +1,98 @@ +import copy +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +from colossalai.auto_parallel.tensor_shard.initialize import initialize_model +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port + + +class MLP(torch.nn.Module): + + def __init__(self, in_features): + super().__init__() + self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False) + self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False) + + def forward(self, x): + x = self.linear_1(x) + x = self.linear_2(x) + + return x + + +def check_compatibility_with_ddp(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = MLP(4).cuda() + input = torch.rand(4, 4).cuda() + output_compare = model(input) + loss_compare = output_compare.sum() + loss_compare.backward() + grad_compare = copy.deepcopy(model.linear_1.weight.grad) + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + meta_args = {'x': torch.rand(4, 4).to('meta')} + gm, solution = initialize_model(model, + meta_args=meta_args, + device_mesh=device_mesh, + return_solution=True, + solver_preference='tp', + shard_option='shard_last_axis') + + msg = '| TP strategy combination chosen by auto-parallel solver |' + msg_length = len(msg) + if rank == 0: + print('=' * msg_length) + print(msg) + print('=' * msg_length) + for strategy in solution: + print(strategy) + print('=' * msg_length) + + dp_process_group = None + for (ranks, process_group_handle) in device_mesh.process_groups_dict[0]: + if rank in ranks: + dp_process_group = process_group_handle + assert dp_process_group is not None + gm = DDP(gm, process_group=dp_process_group) + output = gm(input) + + assert_close(output, output_compare) + print(f'output on rank{rank} is correct') + loss = output.sum() + + loss.backward() + + if rank in (0, 2): + assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 0, 8)) + + if rank in (1, 3): + assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 8, 8)) + + print(f'gradient on rank{rank} is correct') + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_compatibility_with_ddp(): + world_size = 4 + run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_compatibility_with_ddp() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py new file mode 100644 index 000000000..b4080c545 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -0,0 +1,108 @@ +import copy +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +from colossalai.auto_parallel.tensor_shard.initialize import initialize_model +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper +from colossalai.tensor.process_group import ProcessGroup +from colossalai.testing import assert_close, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port, get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx + + +class MLP(torch.nn.Module): + + def __init__(self, in_features): + super().__init__() + self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False) + self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False) + + def forward(self, x): + x = self.linear_1(x) + x = self.linear_2(x) + + return x + + +def check_auto_parallel_with_gemini(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = MLP(4).half().cuda() + + input = torch.rand(4, 4).half().cuda() + output_compare = model(input) + loss_compare = output_compare.sum() + loss_compare.backward() + grad_compare = copy.deepcopy(model.linear_1.weight.grad) + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + meta_args = {'x': torch.rand(4, 4).half().to('meta')} + gm, solution = initialize_model(model, + meta_args=meta_args, + device_mesh=device_mesh, + return_solution=True, + solver_preference='tp', + shard_option='shard_last_axis') + + if rank == 0: + msg = '| TP strategy combination chosen by auto-parallel solver |' + msg_length = len(msg) + print('=' * msg_length) + print(msg) + print('=' * msg_length) + for strategy in solution: + print(strategy) + print('=' * msg_length) + + dp_process_group = ProcessGroup(rank=rank, ranks=[0, 1, 2, 3], tp_degree=2, dp_degree=2) + gemini_config = dict(strict_ddp_mode=False, + device=get_current_device(), + placement_policy='cpu', + pin_memory=True, + search_range_mb=128) + + post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group) + gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config) + optimizer = HybridAdam(gm.parameters(), betas=(0, 0)) + optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1) + output = gm(input) + assert_close(output, output_compare) + print(f'output on rank{rank} is correct') + loss = output.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + + if rank in (0, 2): + assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 0, 8).flatten()) + + if rank in (1, 3): + assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 8, 8).flatten()) + + print(f'gradient on rank{rank} is correct') + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_auto_parallel_with_gemini(): + world_size = 4 + run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_auto_parallel_with_gemini() From 0b2a738393bbc4b37efd5dfe9e60f33587357749 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Wed, 15 Feb 2023 09:54:32 +0800 Subject: [PATCH 15/30] [autoparallel] remove deprecated codes (#2664) --- .../tensor_shard/deprecated/__init__.py | 6 - .../tensor_shard/deprecated/_utils.py | 142 ---- .../tensor_shard/deprecated/constants.py | 83 -- .../tensor_shard/deprecated/cost_graph.py | 172 ---- .../tensor_shard/deprecated/graph_analysis.py | 163 ---- .../deprecated/op_handler/__init__.py | 15 - .../op_handler/batch_norm_handler.py | 492 ------------ .../deprecated/op_handler/bcast_op_handler.py | 552 ------------- .../deprecated/op_handler/conv_handler.py | 609 -------------- .../deprecated/op_handler/dot_handler.py | 756 ------------------ .../op_handler/embedding_handler.py | 179 ----- .../op_handler/layer_norm_handler.py | 241 ------ .../deprecated/op_handler/operator_handler.py | 149 ---- .../deprecated/op_handler/reshape_handler.py | 89 --- .../op_handler/strategy_generator.py | 45 -- .../op_handler/unary_elementwise_handler.py | 88 -- .../deprecated/op_handler/where_handler.py | 186 ----- .../tensor_shard/deprecated/options.py | 11 - .../deprecated/sharding_strategy.py | 91 --- .../tensor_shard/deprecated/solver.py | 469 ----------- .../deprecated/strategies_constructor.py | 426 ---------- .../test_deprecated_cost_graph.py | 96 --- .../test_deprecated_batch_norm_handler.py | 118 --- .../test_deprecated_bcast_handler.py | 75 -- .../test_deprecated_bcast_matmul.py | 54 -- .../test_deprecated_conv_handler.py | 90 --- .../test_deprecated_dot_handler.py | 83 -- .../test_deprecated_layer_norm_handler.py | 70 -- .../test_deprecated_reshape_handler.py | 59 -- .../test_deprecated_where_handler.py | 66 -- .../test_deprecated_shape_consistency_pass.py | 86 -- .../test_deprecated/test_deprecated_solver.py | 79 -- .../test_deprecated_solver_with_gpt.py | 81 -- .../test_deprecated_solver_with_mlp.py | 94 --- .../test_deprecated_strategies_constructor.py | 103 --- 35 files changed, 6118 deletions(-) delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/__init__.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/_utils.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/constants.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/options.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/sharding_strategy.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/solver.py delete mode 100644 colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_batch_norm_handler.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_layer_norm_handler.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py b/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py deleted file mode 100644 index bd47f2adf..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .cost_graph import CostGraph -from .graph_analysis import GraphAnalyser -from .options import SolverOptions -from .sharding_strategy import ShardingStrategy, StrategiesVector -from .solver import Solver -from .strategies_constructor import StrategiesConstructor diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py b/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py deleted file mode 100644 index d6af7ad57..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py +++ /dev/null @@ -1,142 +0,0 @@ -import functools -import operator -import warnings -from functools import reduce -from typing import Dict, List, Optional, Union - -import torch -from torch.fx.node import Node - -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.tensor.sharding_spec import ShardingSpec - -from .constants import INFINITY_COST - - -def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, - dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: - """ - Generate the sharding spec of the tensor based on the given dim_partition_dict. - - - Args: - input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node. - device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster. - dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs, the key is the tensor dimension and the value is the mesh dimension for sharding. - """ - - if isinstance(input_, Node): - assert hasattr(input_, '_meta_data'), f'The given node has no attribte _meta_data' - meta_tensor = input_._meta_data - assert meta_tensor is not None, "The given node's _meta_data attribute is None" - shape = meta_tensor.shape - elif isinstance(input_, torch.Tensor): - shape = input_.shape - else: - raise TypeError( - f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.' - ) - for dim_index, sharding_index_list in dim_partition_dict.items(): - sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] - sharding_size = reduce(operator.mul, sharding_list, 1) - assert shape[ - dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' - - sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict) - return sharding_spec - - -def generate_resharding_costs(nodes: List[Node], - sharding_specs: List[ShardingSpec], - count_backward: Optional[bool] = True, - dtype: Optional[torch.dtype] = None, - index=None): - ''' - Compute the resharding costs with this specific strategy. - - Argument: - nodes (List[Node]): a list of nodes - sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes. - count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference. - dtype (Optional[torch.dtype]): the data type for cost calculation, default is None. - ''' - # The resharding_cost of weight is counted due to sharing weight cases. - resharding_costs = {} - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - - # shape consistency manager is a singleton class - shape_consistency_manager = ShapeConsistencyManager() - - for input_node, input_spec in zip(nodes, sharding_specs): - resharding_costs[input_node] = [] - for strategy in input_node.strategies_vector: - input_sharding_spec = strategy.output_sharding_spec - if not isinstance(input_sharding_spec, ShardingSpec): - assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.' - input_sharding_spec = input_sharding_spec[index] - assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' - try: - # compute the resharding cost - _, _, total_resharding_cost = shape_consistency_manager.shape_consistency( - input_sharding_spec, input_spec) - - # we need multiply the size of elem dtype to get correct communication cost - resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes - except AssertionError as e: - warnings.warn(f'{e}') - resharding_cost = INFINITY_COST - resharding_costs[input_node].append(resharding_cost) - return resharding_costs - - -def ignore_sharding_exception(func): - """ - A function wrapper which executes the function with a specified seed. - """ - - @functools.wraps(func) - def wrapper(*args, **kwargs): - try: - rst = func(*args, **kwargs) - return rst - except AssertionError as e: - warnings.warn(f'{e}') - - return wrapper - - -def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size): - dim_partition_list = [] - # enumerate all the 2D sharding cases - for i in range(dim_size): - for j in range(i + 1, dim_size): - dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]} - dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]} - dim_partition_list.append(dim_partition_dict_0) - dim_partition_list.append(dim_partition_dict_1) - for i in range(dim_size): - dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]} - dim_partition_list.append(dim_partition_dict_flatten) - - return dim_partition_list - - -def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size): - dim_partition_list = [] - # enumerate all the 1D sharding cases - for i in range(dim_size): - dim_partition_dict_0 = {i: [mesh_dim_0]} - dim_partition_list.append(dim_partition_dict_0) - - return dim_partition_list - - -def generate_sharding_size(dim_partition_dict, device_mesh): - total_sharding_size = 1 - for mesh_dim_list in dim_partition_dict.values(): - mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list] - sharding_size = reduce(operator.mul, mesh_dim_sharding_size) - total_sharding_size *= sharding_size - - return total_sharding_size diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/constants.py b/colossalai/auto_parallel/tensor_shard/deprecated/constants.py deleted file mode 100644 index 91c20d343..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/constants.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import operator - -__all__ = [ - 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', - 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP', - 'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST' -] - -ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] -ELEMENTWISE_FUNC_OP = [ - torch.abs, - torch.cos, - torch.exp, - operator.neg, - torch.multiply, - torch.nn.functional.relu, - torch.nn.functional.dropout, - # softmax should not be here - torch.nn.functional.softmax -] -ELEMENTWISE_METHOD_OP = [ - torch.Tensor.to, - torch.Tensor.type, - # TODO: contiguous maybe need some extra processes. - torch.Tensor.contiguous -] -RESHAPE_FUNC_OP = [torch.flatten, torch.reshape] -RESHAPE_METHOD_OP = [ - torch.Tensor.view, - torch.Tensor.unsqueeze, - torch.Tensor.split, - torch.Tensor.permute, - torch.Tensor.transpose, -] -BCAST_FUNC_OP = [ - torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub, - operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh -] -CONV_MODULE_OP = [ - torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, - torch.nn.ConvTranspose3d -] -CONV_FUNC_OP = [ - torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d -] -EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding] -LINEAR_MODULE_OP = [torch.nn.Linear] -LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm] -BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm] -LAYERNORM_MODULE_OP = [torch.nn.LayerNorm] -POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d] -NON_PARAM_FUNC_OP = [ - torch.flatten, - torch.reshape, - torch.abs, - torch.cos, - torch.exp, - operator.neg, - torch.multiply, - torch.nn.functional.relu, - torch.nn.functional.dropout, - torch.flatten, - torch.where, - operator.pow, - torch.pow, - torch.tanh, - torch.add, - torch.sub, - torch.mul, - torch.div, - torch.floor_divide, - torch.true_divide, - operator.add, - operator.sub, - operator.mul, - operator.floordiv, - operator.truediv, - # softmax should not be here - torch.nn.functional.softmax -] - -INFINITY_COST = 1e13 diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py b/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py deleted file mode 100644 index 239d02115..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py +++ /dev/null @@ -1,172 +0,0 @@ -from typing import List -import math -from torch.fx.node import Node -from .constants import INFINITY_COST - - -class CostGraph: - ''' - A graph data structure to simplify the edge cost graph. It has two main functions: - 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in - CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. - 2. To reduce the searching space, we merge computationally-trivial operators, such as - element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will - be given by the StrategiesVector depending on the type of target node and following nodes. - - Argument: - leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph. - simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True) - ''' - - def __init__(self, leaf_strategies, simplify=True): - self.leaf_strategies = leaf_strategies - self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] - # stores number of strategies in each node - self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies} - # extra_node_costs will store the extra costs introduced by merging nodes - self.extra_node_costs = {} - self.following_dict = {} - self.simplify = simplify - self._build_cost_graph() - - def _remove_invalid_node(self, node, attr_name): - remove_list = [] - target_node_list = getattr(node, attr_name, []) - for target_node in target_node_list: - if target_node not in self.nodes: - remove_list.append(target_node) - for element in remove_list: - target_node_list.remove(element) - - def _build_cost_graph(self): - ''' - This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be - set to node. - ''' - self.edge_costs = {} - if self.simplify: - self.merge_pair = [] - for strategies_vector in self.leaf_strategies: - # build edge_cost - dst_node = strategies_vector.node - for src_node in strategies_vector.predecessor_nodes: - if src_node not in self.nodes: - continue - node_pair = (src_node, dst_node) - # src_index = strategies_vector.predecessor_nodes.index(src_node) - edge_cost = {} - for i in range(len(strategies_vector)): - for j in range(len(src_node.strategies_vector)): - edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j] - self.edge_costs[node_pair] = edge_cost - # add parents and children attribute to node - setattr(dst_node, 'parents', strategies_vector.predecessor_nodes) - setattr(dst_node, 'children', strategies_vector.successor_nodes) - self._remove_invalid_node(dst_node, 'parents') - self._remove_invalid_node(dst_node, 'children') - - if self.simplify and strategies_vector.check_merge(): - for followed_node in strategies_vector.predecessor_nodes: - self.merge_pair.append((followed_node, dst_node)) - - def get_edge_cost(self, src_node, dst_node): - return self.edge_costs[(src_node, dst_node)] - - def merge_node(self, src_node, dst_node): - ''' - To merge dst_node into src_node, we need to do it in following steps: - - 1. For each strategy in dst_node, we need to pick an appropriate strategy - of src_node to merge, it is important because the logical resharding costs - between the parents node of src_node and merged node depend on the src_node - strategies dispatching. For example, for the graph 0->1->2, after merging node 1 - into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)] - x represents the picking strategy of node 1 merged into node 2 strategy 0. - - 2. We need to accumulate the extra costs introduced by merging nodes, the extra costs - contains two parts, one is resharding costs between src_node strategy and dst_node strategy, - another is the origin extra costs in src_node strategy. - - 3. Build connections between new node pairs, and remove the src_node after all consumer nodes - detached from it. - - Argument: - src_node(Node): The node will be merged into dst_node. - dst_node(Node): The node to integrate src_node. - ''' - src_node_index = dst_node.parents.index(src_node) - # build merge_map - merge_map = {} - for src_index, strategy in enumerate(src_node.strategies_vector): - min_cost = INFINITY_COST - lowest_cost_index = -1 - for dst_index, dst_strategy in enumerate(dst_node.strategies_vector): - resharding_cost = dst_strategy.resharding_costs[src_node][src_index] - if resharding_cost <= min_cost: - min_cost = resharding_cost - lowest_cost_index = dst_index - merge_map[src_index] = lowest_cost_index - - # extra_node_cost for src node - self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node] - for src_index, strategy in enumerate(src_node.strategies_vector): - target_strate_index = merge_map[src_index] - target_strategy = dst_node.strategies_vector[target_strate_index] - self.extra_node_costs[src_node][src_index] += target_strategy.resharding_costs[src_node][src_index] - if dst_node in self.extra_node_costs: - self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index] - - # add new node pair to cost graph - for child_node in dst_node.children: - new_node_pair = (src_node, child_node) - old_node_pair = (dst_node, child_node) - if new_node_pair in self.edge_costs: - continue - edge_cost = {} - for i in range(self.node_lens[src_node]): - for j in range(self.node_lens[child_node]): - dst_strate_index = merge_map[i] - # dst_strategy = dst_node.strategies_vector[dst_strate_index] - edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)] - if new_node_pair not in self.edge_costs: - self.edge_costs[new_node_pair] = edge_cost - else: - # we should accumulate the resharding costs if args of child node contain - # both src node and dst node. - for index_pair, resharding_cost in self.edge_costs[new_node_pair]: - self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair] - - # connect src node and children of dst node - dst_node.parents.remove(src_node) - src_node.children.remove(dst_node) - self.edge_costs.pop((src_node, dst_node)) - for child_node in dst_node.children: - if child_node not in src_node.children: - src_node.children.append(child_node) - if src_node not in child_node.parents: - child_node.parents.append(src_node) - # remove dst node from cost graph when dst node has no producer. - if len(dst_node.parents) == 0: - child_node.parents.remove(dst_node) - node_pair = (dst_node, child_node) - self.edge_costs.pop(node_pair) - if len(dst_node.parents) == 0: - self.following_dict[dst_node] = src_node - dst_node.children = [] - - def _reindexing_src(self, src): - if src not in self.following_dict: - return src - return self._reindexing_src(self.following_dict[src]) - - def simplify_graph(self): - if not self.simplify: - return - self.merge_pair.reverse() - for (src_node, dst_node) in self.merge_pair: - self.merge_node(src_node, dst_node) - self.merge_pair.reverse() - reindexing_following_dict = {} - for dst, src in self.following_dict.items(): - reindexing_following_dict[dst] = self._reindexing_src(src) - self.following_dict = reindexing_following_dict diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py deleted file mode 100644 index 831e7eadd..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py +++ /dev/null @@ -1,163 +0,0 @@ -from dataclasses import dataclass -from torch.fx.node import Node -from torch.fx.graph import Graph -from torch.fx.graph_module import GraphModule -from collections import OrderedDict as ODict -from typing import List, OrderedDict, Union, Any -from colossalai.fx.passes.utils import get_node_module - -__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser'] - - -@dataclass -class LiveVariable: - """ - LiveVariable is a data structure to store the meta information of a variable for liveness analysis. - """ - name: str - node: Node - is_inplace: bool - - -class LiveVariableVector(list): - """ - LiveVariableVector is a data structure to store the list of LiveVariable objects. - """ - - def exists(self, name) -> bool: - """ - Check if a variable has already existed in the current list by name. - """ - for var in self: - if name == var.name: - return True - return False - - def get(self, name) -> LiveVariable: - for var in self: - if name == var.name: - return var - raise KeyError(f"Variable {name} is not found") - - def copy(self) -> "LiveVariableVector": - """ - Create a copy of this vector - """ - vector = LiveVariableVector() - for var in self: - vector.append(var) - return vector - - -@dataclass -class LiveStage: - """ - LiveStage is a data structure to record the living variables at this current node. - """ - name: str - node: Node - all_live_vars: LiveVariableVector - unique_live_vars: LiveVariableVector - - -class GraphAnalyser: - - def __init__(self, gm: GraphModule): - self._gm = gm - self._graph = gm.graph - - @property - def gm(self) -> GraphModule: - """ - Return the GraphModule object associated with this analyser. - """ - return self._gm - - @property - def graph(self) -> Graph: - """ - Return the Graph object associated with this analyser. - """ - return self._graph - - def liveness_analysis(self) -> List[LiveStage]: - """ - Analyse the graph to obtain the variable liveness information. This function returns - an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object. - """ - compute_nodes = self.graph.nodes - liveness_list = [] - - # checked: record all variables created since the first stage - # all: record the live variables only exist until the current stage. - # this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage. - # unique: record the unique live variables only exist until the current stage. - # this is different from `all list` as some variables are duplicated. - checked_variables = LiveVariableVector() - all_live_variables = LiveVariableVector() - unique_live_vars = LiveVariableVector() - - for idx, node in enumerate(compute_nodes): - ############################# - # find new living variables # - ############################# - # detect whether the current op is an in-place op - # if it is an in-place op, we would deem it as a duplciate var - is_inplace = False - if node.op == 'call_function': - # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True) - if node.kwargs.get('inplace', False): - is_inplace = True - elif node.op == 'call_module': - # to check if this is an inplace op such as torch.nn.Relu(inplace=True) - module = get_node_module(node) - if getattr(module, 'inplace', False): - is_inplace = True - - # add the output var - meta = getattr(node, '_meta_data', None) - live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace) - if not is_inplace: - unique_live_vars.append(live_var) - checked_variables.append(live_var) - all_live_variables.append(live_var) - - # check if any input is not checked yet - for arg in node.args: - if not isinstance(arg, Node): - continue - arg_name = arg.name - if not checked_variables.exists(arg_name): - live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False) - all_live_variables.append(live_var_from_arg) - checked_variables.append(live_var_from_arg) - unique_live_vars.append(live_var_from_arg) - - # TODO: add the logic to remove live variables - # this should be completed if we are able to trace the backward compute graph - - # add this stage to liveness dict - stage = LiveStage(name=node.name, - node=node, - all_live_vars=all_live_variables.copy(), - unique_live_vars=unique_live_vars.copy()) - # if a LiveStage is covered by another LiveStage, we just keep the larger one. - replace = False - for index, prev_stage in enumerate(liveness_list): - all_covered = True - for ele in prev_stage.unique_live_vars: - if ele not in stage.unique_live_vars: - all_covered = False - break - if all_covered: - replace = True - break - if replace: - liveness_list[index] = stage - else: - liveness_list.append(stage) - - return liveness_list - - def get_alias_set(self): - pass diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py deleted file mode 100644 index 723e1bcf9..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from .batch_norm_handler import BatchNormHandler -from .bcast_op_handler import BcastOpHandler -from .conv_handler import ConvHandler -from .dot_handler import DotHandler -from .embedding_handler import EmbeddingHandler -from .layer_norm_handler import LayerNormHandler -from .operator_handler import OperatorHandler -from .reshape_handler import ReshapeHandler -from .unary_elementwise_handler import UnaryElementwiseHandler -from .where_handler import WhereHandler - -__all__ = [ - 'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler', - 'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler', 'LayerNormHandler' -] diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py deleted file mode 100644 index 519436270..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py +++ /dev/null @@ -1,492 +0,0 @@ -import operator -from functools import reduce - -import torch -from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ - ignore_sharding_exception -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) - -from .operator_handler import OperatorHandler - -__all__ = ['BatchNormHandler'] - - -class BatchNormHandler(OperatorHandler): - """ - A OperatorHandler which deals with the sharding strategies of normalization. - - To keep the math consistency, there are two way to do BatchNorm if the input - shards on batch dimension: - 1. We gather the input partitions through batch dimension, then do the normal BatchNorm. - 2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help - us to keep the computing correctness. - In this handler, both methods will be considered. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.input_data = self.predecessor_node[0]._meta_data - self.weight = self.module_named_parameters['weight'] - self.bias = self.module_named_parameters['bias'] - self.output_data = self.node._meta_data - self._sanity_check() - - def _sanity_check(self): - ''' - In sanity check, we need make sure the input data having correct dimension size. - For BatchNorm1d, the dim of input data should be 3([N, C, L]). - For BatchNorm2d, the dim of input data should be 4([N, C, H, W]). - For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]). - ''' - assert self.input_data.dim() in (3, 4, - 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' - - def _generate_compute_cost(self, bs, channel_in): - ''' - Compute the computation cost per device with this specific strategy. - - Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. - - Argument: - bs(int): Batch size of the input data. - channel_in(int): The channel dimension of input data. - - Return: - compute_cost(float): Computation cost per device with this specific strategy - ''' - # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. - # TODO: a constant coefficient need to be added. - # 1D: (L) * N * Cin - # 2D: (H * W) * N * Cin - # 3D: (H * W * D) * N * Cin - - input_size = self.input_data.shape[2:] - input_size_product = reduce(operator.mul, input_size, 1) - forward_compute_cost = input_size_product * bs * channel_in - backward_activation_compute_cost = input_size_product * bs * channel_in - backward_weight_compute_cost = input_size_product * bs * channel_in - backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost - compute_cost = forward_compute_cost + backward_compute_cost - return compute_cost - - def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight): - ''' - Compute the memory cost per device with this specific strategy. - - Argument: - sharding_size_forward(int): The forward activation will be divided - into sharding_size_forward number partions. - sharding_size_backward_activation(int): The backward activation will - be divided into sharding_size_backward_activation number partions. - sharding_size_weight(int): The backward weight will be divided - into sharding_size_weight number partions. - - Return: - memory_cost(Tuple[float]): Memory cost per device with this - specific strategy, the first element of this tuple is forward - memory cost, and the second element of this tuple is backward - memory cost. - memory_cost_forward(float): Memory cost of forward activation per - device with this specific strategy. - memory_cost_backward_activation(float): Memory cost of backward activation - per device with this specific strategy. - ''' - # compute the memory cost of this strategy - dtype = self.input_data.dtype - numel_output = self.output_data.numel() - numel_input = numel_output - numel_weight = self.weight.numel() - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - - # forward memory_cost - memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward - memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight - memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight - - # backward memory_cost - memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation - memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight - memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight - - # memory_cost pair - memory_cost = (memory_cost_forward, memory_cost_backward) - - return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation - - @ignore_sharding_exception - def split_input_channel(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' - - dim_partition_dict_for_input = {1: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {0: [mesh_dim_0]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {1: [mesh_dim_0]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] - channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0] - compute_cost = self._generate_compute_cost(bs, channel_in) - - # compute the memory cost of this strategy - sharding_size_forward = self.device_mesh.shape[mesh_dim_0] - sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] - sharding_size_weight = self.device_mesh.shape[mesh_dim_0] - memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, - sharding_size_weight) - - # This strategy do not need to do all_reduce operation - communication_cost = 0 - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - - self.strategies_vector.append(sharding_strategies) - - # shard the output batch dimension to get all possible sharding strategy from this basic strategy - new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' - - dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]} - new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - # the computation cost is all the same - new_compute_cost = compute_cost - - # the memory cost need to be recomputed - # compute the memroy cost of new strategy - new_sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] - sharding_size_weight = self.device_mesh.shape[mesh_dim_0] - new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost( - new_sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) - - # the communication cost need to count the sharding cost into this strategy - # compute the communication cost of new strategy - origin_communication_cost = communication_cost - tiny_shard_cost = 10 - new_forward_communication_cost = tiny_shard_cost - # we need to all gather the batch dimension for the basic strategy - new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation, mesh_dim_1) - new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost - - sharding_strategies = ShardingStrategy(new_name, - output_sharding_spec=new_sharding_spec_for_output, - compute_cost=new_compute_cost, - communication_cost=new_communication_cost, - memory_cost=new_memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' - - dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] - channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] * - self.device_mesh.shape[mesh_dim_1]) - compute_cost = self._generate_compute_cost(bs, channel_in) - - # compute the memory cost of this strategy - sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, - sharding_size_weight) - - # This strategy do not need to do all_reduce operation - communication_cost = 0 - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def non_split(self, mesh_dim_0, mesh_dim_1): - name = f'RR = RR x R' - - dim_partition_dict_for_input = {} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] - channel_in = self.input_data.shape[1] - compute_cost = self._generate_compute_cost(bs, channel_in) - - # compute the memory cost of this strategy - sharding_size_forward = 1 - sharding_size_backward_activation = 1 - sharding_size_weight = 1 - memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, - sharding_size_weight) - - # This strategy do not need to do all_reduce operation - communication_cost = 0 - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - - self.strategies_vector.append(sharding_strategies) - - def _construct_batch_sharding_strategies(mesh_dim_list, new_name): - dim_partition_dict_for_output = {0: mesh_dim_list} - new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # the computation cost is all the same - new_compute_cost = compute_cost - - # the memory cost need to be recomputed - new_sharding_size_input = 1 - for mesh_dim in mesh_dim_list: - new_sharding_size_input = new_sharding_size_input * self.device_mesh.shape[mesh_dim] - new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost( - new_sharding_size_input, sharding_size_backward_activation, sharding_size_weight) - - # the communication cost need to count the sharding cost into this strategy - origin_communication_cost = communication_cost - tiny_shard_cost = 10 - new_forward_communication_cost = tiny_shard_cost - if len(mesh_dim_list) == 1: - new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation, - mesh_dim_list[0]) - else: - new_backward_communication_cost = self.device_mesh.flatten_device_mesh.all_gather_cost( - memory_cost_backward_activation, 0) - new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost - - new_sharding_strategy = ShardingStrategy(new_name, - output_sharding_spec=new_sharding_spec_for_output, - compute_cost=new_compute_cost, - communication_cost=new_communication_cost, - memory_cost=new_memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, - sharding_spec_for_weight)) - - return new_sharding_strategy - - # shard the output batch dimension to get all possible sharding strategy from this basic strategy - # shard on mesh_dim_0 - new_name = f'S{mesh_dim_0}R = RR x R' - mesh_dim_list = [mesh_dim_0] - new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name) - self.strategies_vector.append(new_sharding_strategy) - - # shard on mesh_dim_1 - new_name = f'S{mesh_dim_1}R = RR x R' - mesh_dim_list = [mesh_dim_1] - new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name) - self.strategies_vector.append(new_sharding_strategy) - - # shard on mesh_dim_0, mesh_dim_1 - new_name = f'S{mesh_dim_0}{mesh_dim_1}R = RR x R' - mesh_dim_list = [mesh_dim_0, mesh_dim_1] - new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name) - self.strategies_vector.append(new_sharding_strategy) - - @ignore_sharding_exception - def split_input_batch(self, mesh_dim_0): - name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' - - dim_partition_dict_for_input = {0: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] - channel_in = self.input_data.shape[1] - compute_cost = self._generate_compute_cost(bs, channel_in) - - # compute the memory cost of this strategy - sharding_size_forward = self.device_mesh.shape[mesh_dim_0] - sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] - sharding_size_weight = 1 - memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, - sharding_size_backward_activation, - sharding_size_weight) - - # the all reduce communication will happen during the sync bn computing. - communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' - - dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]) - channel_in = self.input_data.shape[1] - compute_cost = self._generate_compute_cost(bs, channel_in) - - # compute the memory cost of this strategy - sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - sharding_size_weight = 1 - memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, - sharding_size_backward_activation, - sharding_size_weight) - - # the all reduce communication will happen during the sync bn computing. - communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward_activation, 0) - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' - - dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {0: [mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] - channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1] - compute_cost = self._generate_compute_cost(bs, channel_in) - - # compute the memory cost of this strategy - sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - sharding_size_weight = self.device_mesh.shape[mesh_dim_1] - memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, - sharding_size_backward_activation, - sharding_size_weight) - - # the all reduce communication will happen during the sync bn computing. - communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - - self.strategies_vector.append(sharding_strategies) - - def register_strategy(self) -> StrategiesVector: - ''' - Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector. - - Example: - norm_handler = BatchNormHandler(node, strategies_vector, - self.shape_consistency_manager) - norm_handler.register_strategy() - for strategy in norm_handler.strategies_vector: - print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}') - - Output: - RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0 - RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0 - RR = RR x R, computation_cost: 262144, memory_cost: 1048576 - RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0 - ''' - - # RS = RS x S and strategies based on it, such as - # SS = RS x S - self.split_input_channel(0, 1) - self.split_input_channel(1, 0) - - # RR = RR x R and strategies based on it, such as - # SR = SR x R - self.non_split(0, 1) - - # RS01 = RS01 x S01 - self.split_input_channel_1d(0, 1) - - # SR = SR x R WITH SYNC_BN - self.split_input_batch(0) - self.split_input_batch(1) - - # SS = SS x S WITH SYNC_BN - self.split_input_both_dim(0, 1) - self.split_input_both_dim(1, 0) - - # S01R = S01R x R WITH SYNC_BN - self.split_input_batch_1d(0, 1) - - return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py deleted file mode 100644 index 6ac6dce76..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py +++ /dev/null @@ -1,552 +0,0 @@ -import operator -import warnings -from copy import deepcopy -from functools import reduce -from typing import Dict, List - -import torch -from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding, - enumerate_all_possible_2d_sharding, - ignore_sharding_exception) -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.tensor.sharding_spec import ShardingSpec - -from .operator_handler import OperatorHandler - -__all__ = ['BcastOpHandler'] - - -class BcastOpHandler(OperatorHandler): - """ - An OperatorHandler which deals with the sharding strategies of broadcast operators(such as operator.add). - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert len(self.predecessor_node) == 2 - self.lhs_data = self.predecessor_node[0]._meta_data - self.rhs_data = self.predecessor_node[1]._meta_data - self.lhs = self.predecessor_node[0] - self.rhs = self.predecessor_node[1] - self.output_data = self.node._meta_data - - def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: - shape = list(input_.shape) - - # padding the shape to the same length as output_data - while len(shape) < self.output_data.dim(): - shape.insert(0, 1) - shape = torch.Size(shape) - - # if the sharding happens on a size one dimension, we should record it as R. - processed_dim_partition_dict = deepcopy(dim_partition_dict) - for dim_index, _ in dim_partition_dict.items(): - if shape[dim_index] == 1: - processed_dim_partition_dict.pop(dim_index) - for dim_index, sharding_index_list in processed_dim_partition_dict.items(): - sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] - sharding_size = reduce(operator.mul, sharding_list, 1) - assert shape[ - dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' - sharding_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=shape, - dim_partition_dict=processed_dim_partition_dict) - - return sharding_spec - - def _generate_compute_cost(self, total_sharding_size): - lhs_matrix_shape = self.lhs_data.shape[-2:] - rhs_matrix_shape = self.rhs_data.shape[-2:] - batch_dimensions_shape = self.output_data.shape[:-2] - batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1) - compute_cost = reduce( - operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size - return compute_cost - - def _generate_resharding_costs(self, sharding_specs): - # The resharding_cost of weight is counted due to sharing weight cases. - dtype = self.node._meta_data.dtype - nodes = self.predecessor_node - resharding_costs = {} - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - - # shape consistency manager is a singleton class - shape_consistency_manager = ShapeConsistencyManager() - - for input_node, input_spec in zip(nodes, sharding_specs): - resharding_costs[input_node] = [] - for strategy in input_node.strategies_vector: - input_sharding_spec = strategy.output_sharding_spec - assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' - # if the input shape is smaller than the target input, we will fill the input to the same length as target. - # Then, use the padded input sharding spec to compute the resharding cost. - if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape): - new_entire_shape = list(input_sharding_spec.entire_shape) - while len(new_entire_shape) < len(input_spec.entire_shape): - new_entire_shape.insert(0, 1) - new_entire_shape = torch.Size(new_entire_shape) - new_device_mesh = input_sharding_spec.device_mesh - new_dim_partition_dict = input_sharding_spec.dim_partition_dict - input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh, - entire_shape=new_entire_shape, - dim_partition_dict=new_dim_partition_dict) - - # compute the resharding cost - _, _, total_resharding_cost = shape_consistency_manager.shape_consistency( - input_sharding_spec, input_spec) - - # we need multiply the size of elem dtype to get correct communication cost - resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes - resharding_costs[input_node].append(resharding_cost) - - return resharding_costs - - def _convert_partition_dict_to_sharding_spec(self, dim_partition_list): - - sharding_spec_list = [] - check_duplicated_list = [] - for output_dim_partition_dict in dim_partition_list: - try: - output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict) - except AssertionError as e: - warnings.warn(f'{e}') - break - sharding_seq = output_sharding_spec.sharding_sequence - if sharding_seq not in check_duplicated_list: - check_duplicated_list.append(sharding_seq) - sharding_spec_list.append(output_sharding_spec) - - return sharding_spec_list - - def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): - # use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity. - - output_dim_partition_list = [] - dim_size = self.output_data.dim() - # enumerate all the 2D sharding cases - sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) - output_dim_partition_list.extend(sharding_list_2d) - - # enumerate all the 1D sharding cases - sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size) - output_dim_partition_list.extend(sharding_list_1d_on_dim_0) - sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size) - output_dim_partition_list.extend(sharding_list_1d_on_dim_1) - - # add empty dict for fully replicated case - output_dim_partition_list.append({}) - output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list) - - return output_sharding_spec_list - - @ignore_sharding_exception - def _register_strategy(self, output_sharding_spec): - dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict - sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input) - sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_input) - - name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' - dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) - - # compute the computation cost of this strategy - sharding_dims = [] - for mesh_dims in dim_partition_dict_for_output.values(): - for mesh_dim in mesh_dims: - sharding_dims.append(self.device_mesh.shape[mesh_dim]) - sharding_size = reduce(operator.mul, sharding_dims, 1) - memory_cost = self.output_data.numel() / sharding_size - compute_cost = memory_cost - communication_cost = 0 - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=output_sharding_spec, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) - - self.strategies_vector.append(sharding_strategies) - - ############################################## - #used to generate strategies for torch.matmul# - ############################################## - @ignore_sharding_exception - def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim): - # this dim partition dict only describes the batch dimensions, but in this scenario, - # matrix dimensions are fully replicated, so it do not need extra process. - sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_batch_dim) - sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_batch_dim) - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_batch_dim) - - name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) - - # compute the memory cost of this strategy - batch_sharding_dims = [] - for mesh_dims in dim_partition_dict_for_batch_dim.values(): - for mesh_dim in mesh_dims: - batch_sharding_dims.append(self.device_mesh.shape[mesh_dim]) - batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1) - # in this case, total_sharding_size is equal to the batch sharding size - memory_cost = self.output_data.numel() / batch_sharding_size - - # compute the computation cost of this strategy - compute_cost = self._generate_compute_cost(batch_sharding_size) - - # in this case, no communication takes place. - # TODO: add all-reduce cost if lhs or rhs is type of Parameters. - communication_cost = 0 - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) - - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix): - # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] - # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it. - # In this scenario, matrix dimensions will be sharded on 'i' dimension. - - # in this case, the matrix dimensions of lhs is sharded on 'i' dimension. - dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim) - dim_partition_dict_for_lhs.update({-2: mesh_dim_on_matrix}) - - # in this case, the matrix dimensions of rhs is fully replicated. - dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim) - - # in this case, the matrix dimensions of output is sharded on 'i' dimension. - - dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim) - dim_partition_dict_for_output.update({-2: mesh_dim_on_matrix}) - - # generate sharding specs - sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) - sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs) - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) - - # compute the memory cost of this strategy - total_sharding_dims = [] - - # append batch sharding dims - for mesh_dims in dim_partition_dict_for_batch_dim.values(): - for mesh_dim in mesh_dims: - total_sharding_dims.append(self.device_mesh.shape[mesh_dim]) - - # append the sharding dims on matrix dimension - for mesh_dim in mesh_dim_on_matrix: - total_sharding_dims.append(self.device_mesh.shape[mesh_dim]) - total_sharding_size = reduce(operator.mul, total_sharding_dims, 1) - - # in this case, output_data uses all the sharding dims. - memory_cost = self.output_data.numel() / total_sharding_size - compute_cost = self._generate_compute_cost(total_sharding_size) - - # TODO: add all-reduce cost if lhs or rhs is type of Parameters. - communication_cost = 0 - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) - - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix): - # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] - # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it. - # In this scenario, matrix dimensions will be sharded on 'k' dimension. - - # in this case, the matrix dimensions of lhs is sharded on 'k' dimension. - dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim) - dim_partition_dict_for_lhs.update({-1: mesh_dim_on_matrix}) - - # in this case, the matrix dimensions of rhs is sharded on 'k' dimension. - dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim) - dim_partition_dict_for_rhs.update({-2: mesh_dim_on_matrix}) - - # in this case, the matrix dimensions of output is fully replicated. - dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim) - - # generate sharding specs - sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) - sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs) - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) - - # compute the memory cost of this strategy - total_sharding_dims = [] - batch_sharding_dims = [] - # append batch sharding dims - for mesh_dims in dim_partition_dict_for_batch_dim.values(): - for mesh_dim in mesh_dims: - total_sharding_dims.append(self.device_mesh.shape[mesh_dim]) - batch_sharding_dims.append(self.device_mesh.shape[mesh_dim]) - - # append the sharding dims on matrix dimension - for mesh_dim in mesh_dim_on_matrix: - total_sharding_dims.append(self.device_mesh.shape[mesh_dim]) - batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1) - total_sharding_size = reduce(operator.mul, total_sharding_dims, 1) - - # in this case, output_data is fully replicated on matrix dimensions. - memory_cost = self.output_data.numel() / batch_sharding_size - - compute_cost = self._generate_compute_cost(total_sharding_size) - - # TODO: add all-reduce cost if lhs or rhs is type of Parameters. - # The communication takes place during forward activation computation. - if len(mesh_dim_on_matrix) == 1: - communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0]) - else: - communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0) - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) - - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix): - # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] - # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it. - # In this scenario, matrix dimensions will be is sharded on 'j' dimension. - - # in this case, the matrix dimensions of lhs is fully replicated. - dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim) - - # in this case, the matrix dimensions of rhs is sharded on 'j' dimension. - dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim) - dim_partition_dict_for_rhs.update({-1: mesh_dim_on_matrix}) - - # in this case, the matrix dimensions of output is sharded on 'j' dimension. - dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim) - dim_partition_dict_for_output.update({-1: mesh_dim_on_matrix}) - - # generate sharding specs - sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) - sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs) - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) - - # compute the memory cost of this strategy - total_sharding_dims = [] - - # append batch sharding dims - for mesh_dims in dim_partition_dict_for_batch_dim.values(): - for mesh_dim in mesh_dims: - total_sharding_dims.append(self.device_mesh.shape[mesh_dim]) - - # append the sharding dims on matrix dimension - for mesh_dim in mesh_dim_on_matrix: - total_sharding_dims.append(self.device_mesh.shape[mesh_dim]) - total_sharding_size = reduce(operator.mul, total_sharding_dims, 1) - - # in this case, output_data uses all the sharding dims. - memory_cost = self.output_data.numel() / total_sharding_size - compute_cost = self._generate_compute_cost(total_sharding_size) - - # TODO: add all-reduce cost if lhs or rhs is type of Parameters. - # The communication takes place during backward activation computation. - if len(mesh_dim_on_matrix) == 1: - communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0]) - else: - communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0) - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) - - self.strategies_vector.append(sharding_strategies) - - def _registry_1d_strategies_for_matmul(self, dim_partition_dict, mesh_dim_list): - self._split_dim_i(dim_partition_dict, mesh_dim_list) - self._split_dim_k(dim_partition_dict, mesh_dim_list) - self._split_dim_j(dim_partition_dict, mesh_dim_list) - - @ignore_sharding_exception - def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): - dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]} - sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) - - dim_partition_dict_for_rhs = {-2: [mesh_dim_1]} - sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs) - - dim_partition_dict_for_output = {-2: [mesh_dim_0]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) - - # compute the memory cost of this strategy - total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1) - output_sharding_size = reduce(operator.mul, self.output_data.shape, 1) - # in this case, output_data uses all the sharding dims. - memory_cost = self.output_data.numel() / output_sharding_size - compute_cost = self._generate_compute_cost(total_sharding_size) - - # TODO: add all-reduce cost if lhs or rhs is type of Parameters. - # The communication takes place during forward activation computation. - communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1) - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) - - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): - dim_partition_dict_for_lhs = {-1: [mesh_dim_0]} - sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) - - dim_partition_dict_for_rhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]} - sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs) - - dim_partition_dict_for_output = {-1: [mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) - - # compute the memory cost of this strategy - total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1) - output_sharding_size = reduce(operator.mul, self.output_data.shape, 1) - # in this case, output_data uses all the sharding dims. - memory_cost = self.output_data.numel() / output_sharding_size - compute_cost = self._generate_compute_cost(total_sharding_size) - - # TODO: add all-reduce cost if lhs or rhs is type of Parameters. - # The communication takes place during forward and backward activation computation. - communication_cost_forward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0) - communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1) - communication_cost = communication_cost_backward_activation + communication_cost_forward_activation - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) - - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): - dim_partition_dict_for_lhs = {-2: [mesh_dim_0]} - sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) - - dim_partition_dict_for_rhs = {-1: [mesh_dim_1]} - sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs) - - dim_partition_dict_for_output = {-2: [mesh_dim_0], -1: [mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) - - # compute the memory cost of this strategy - total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1) - output_sharding_size = reduce(operator.mul, self.output_data.shape, 1) - # in this case, output_data uses all the sharding dims. - memory_cost = self.output_data.numel() / output_sharding_size - compute_cost = self._generate_compute_cost(total_sharding_size) - - # TODO: add all-reduce cost if lhs or rhs is type of Parameters. - # The communication takes place during backward activation computation. - communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1) - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) - - self.strategies_vector.append(sharding_strategies) - - def _registry_2d_strategies_for_matmul(self): - self._split_lhs_space_both_contract(0, 1) - self._split_lhs_space_both_contract(1, 0) - self._split_rhs_space_both_contract(0, 1) - self._split_rhs_space_both_contract(1, 0) - self._split_lhs_space_rhs_space(0, 1) - self._split_lhs_space_rhs_space(1, 0) - - def register_strategy(self) -> StrategiesVector: - MESH_DIM_LIST = [0, 1] - if self.node.target != torch.matmul: - output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1]) - for output_sharding_spec in output_sharding_specs: - self._register_strategy(output_sharding_spec) - else: - # we only care about the non-computing dimensions, - # therefore, we omit the last two dimensions. - dim_size = self.output_data.dim() - 2 - - # Both device mesh axises are uesd on batch dimensions - dim_partition_dicts_2d = enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1], dim_size) - for dim_partition_dict in dim_partition_dicts_2d: - self._registry_no_split_strategies_for_matmul(dim_partition_dict) - - # Only one device mesh axis is uesd on batch dimensions - for mesh_dim_index in [0, 1]: - dim_partition_dicts_1d = enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index], dim_size) - for dim_partition_dict in dim_partition_dicts_1d: - self._registry_no_split_strategies_for_matmul(dim_partition_dict) - self._registry_1d_strategies_for_matmul(dim_partition_dict, [MESH_DIM_LIST[mesh_dim_index - 1]]) - - # No device mesh axis is uesd on batch dimensions - dim_partition_dict_on_batch_dim = {} - self._registry_no_split_strategies_for_matmul(dim_partition_dict_on_batch_dim) - self._registry_1d_strategies_for_matmul(dim_partition_dict_on_batch_dim, MESH_DIM_LIST) - self._registry_2d_strategies_for_matmul() diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py deleted file mode 100644 index d8952040d..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py +++ /dev/null @@ -1,609 +0,0 @@ -import operator -import warnings -from functools import reduce - -import torch - -from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector - -from .operator_handler import OperatorHandler - -__all__ = ['ConvHandler'] - - -class ConvHandler(OperatorHandler): - """ - An OperatorHandler which deals with the sharding strategies of Convolution. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.input_data = self.predecessor_node[0]._meta_data - self.weight = self.module_named_parameters['weight'] - self.output_data = self.node._meta_data - self._sanity_check() - - def _sanity_check(self): - ''' - In sanity check, we need make sure the input data having correct dimension size. - For Conv1d, the dim of input data should be 3([N, C, L]). - For Conv2d, the dim of input data should be 4([N, C, H, W]). - For Conv3d, the dim of input data should be 5([N, C, H, W, D]). - ''' - assert self.input_data.dim() in (3, 4, - 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' - - def _generate_compute_cost(self, bs, channel_in, channel_out): - ''' - Compute the computation cost per device with this specific strategy. - - Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. - - Argument: - bs(int): Batch size of the input data. - channel_in(int): The channel dimension of input data. - channel_out(int): The out channel of the conv weight. - - Return: - compute_cost(float): Computation cost per device with this specific strategy - ''' - # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. - # 1D: (L) * N * Cout * Cin * kernel - # 2D: (H * W) * N * Cout * Cin * kernel - # 3D: (H * W * D) * N * Cout * Cin * kernel - output_size = self.output_data.shape[2:] - output_size_product = reduce(operator.mul, output_size, 1) - input_size = self.input_data.shape[2:] - input_size_product = reduce(operator.mul, input_size, 1) - kernel_size = self.weight.shape[2:] - kernel_size_product = reduce(operator.mul, kernel_size, 1) - forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product - backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product - backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product - compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost - return compute_cost - - def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight): - ''' - Compute the memory cost per device with this specific strategy. - - Argument: - sharding_size_forward(int): The forward activation will be divided - into sharding_size_forward number partions. - sharding_size_backward_activation(int): The backward activation will - be divided into sharding_size_backward_activation number partions. - sharding_size_weight(int): The backward weight will be divided - into sharding_size_weight number partions. - - Return: - memory_cost(Tuple[float]): Memory cost per device with this - specific strategy, the first element of this tuple is forward - memory cost, and the second element of this tuple is backward - memory cost. - memory_cost_forward(float): Memory cost of forward activation per - device with this specific strategy. - memory_cost_backward_activation(float): Memory cost of backward activation - per device with this specific strategy. - ''' - # compute the memory cost of this strategy - dtype = self.input_data.dtype - numel_output = self.output_data.numel() - numel_input = self.input_data.numel() - numel_weight = self.weight.numel() - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - - # forward memory_cost - memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward - memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight - memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight - - # backward memory_cost - memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation - memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight - memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight - - # memory_cost pair - memory_cost = (memory_cost_forward, memory_cost_backward) - - return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight - - @ignore_sharding_exception - def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' - - dim_partition_dict_for_input = {0: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {1: [mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] - channel_in = self.input_data.shape[1] - channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1] - compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) - - # compute the memory cost of this strategy - sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] - sharding_size_weight = self.device_mesh.shape[mesh_dim_1] - memory_cost, _, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost( - sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) - - # This strategy do not need to do all_reduce operation during forward - communication_cost_forward = 0 - # compute the backward communication cost to all reduce the input activation grad - communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, - mesh_dim_1) - # compute the backward communication cost to all reduce the weight due to data parallel - communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0) - # total communication cost - communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_input_batch(self, mesh_dim_0): - name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' - - dim_partition_dict_for_input = {0: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] - channel_in = self.input_data.shape[1] - channel_out = self.weight.shape[1] - compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) - - # compute the memory cost of this strategy - sharding_size_forward = self.device_mesh.shape[mesh_dim_0] - sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] - sharding_size_weight = 1 - memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward, - sharding_size_backward_activation, - sharding_size_weight) - - # This strategy do not need to do all_reduce operation in forward phase. - communication_cost_forward = 0 - # compute the backward communication cost to all reduce the weight due to data parallel - communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0) - # compute the total cost - communication_cost = communication_cost_forward + communication_cost_backward_weight - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' - - dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {0: [mesh_dim_0]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] - channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1] - channel_out = self.weight.shape[1] - compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) - - # compute the memory cost of this strategy - sharding_size_forward = self.device_mesh.shape[mesh_dim_0] - sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - sharding_size_weight = self.device_mesh.shape[mesh_dim_1] - memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost( - sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) - - # compute the communication cost of this strategy during forward phase - communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1) - # This strategy do not need to do all_reduce operation to compute the input activation grad - communication_cost_backward_activation = 0 - # compute the backward communication cost to all reduce the weight due to data parallel - communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0) - # compute total cost - communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' - - dim_partition_dict_for_input = {1: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {1: [mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] - channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0] - channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1] - compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) - - # compute the memory cost of this strategy - sharding_size_forward = self.device_mesh.shape[mesh_dim_1] - sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] - sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost( - sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) - - # compute the communication cost of this strategy during forward phase - communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) - # compute the communication cost of this strategy during backward phase - communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1) - communication_cost = communication_cost_forward + communication_cost_backward - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_input_in_channel_weight_in_channel(self, mesh_dim_0): - name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' - - dim_partition_dict_for_input = {1: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {0: [mesh_dim_0]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] - channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0] - channel_out = self.weight.shape[1] - compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) - - # compute the memory cost of this strategy - sharding_size_forward = 1 - sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] - sharding_size_weight = self.device_mesh.shape[mesh_dim_0] - memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost( - sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) - - # compute the communication cost of this strategy during forward phase - communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) - # This strategy do NOT need all_reduce during forward phase - communication_cost_backward = 0 - communication_cost = communication_cost_forward + communication_cost_backward - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_weight_out_channel(self, mesh_dim_0): - name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' - - dim_partition_dict_for_input = {} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {1: [mesh_dim_0]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {1: [mesh_dim_0]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] - channel_in = self.input_data.shape[1] - channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_0] - compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) - - # compute the memory cost of this strategy - sharding_size_forward = self.device_mesh.shape[mesh_dim_0] - sharding_size_backward_activation = 1 - sharding_size_weight = self.device_mesh.shape[mesh_dim_0] - memory_cost, _, memory_cost_backward_activation, _ = self._generate_memory_cost( - sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) - - # This strategy do not need to do all_reduce during forward phase - communication_cost_forward = 0 - # compute the communication cost of this strategy during backward phase - communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0) - communication_cost = communication_cost_forward + communication_cost_backward - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def non_split(self): - name = f'RR = RR x RR' - - dim_partition_dict_for_input = {} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] - channel_in = self.input_data.shape[1] - channel_out = self.weight.shape[1] - compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) - - # compute the memory cost of this strategy - sharding_size_forward = 1 - sharding_size_backward_activation = 1 - sharding_size_weight = 1 - memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, - sharding_size_weight) - - # This strategy do not need to do all_reduce in both forward and backward phase - communication_cost = 0 - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' - - dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]) - channel_in = self.input_data.shape[1] - channel_out = self.weight.shape[1] - compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) - - # compute the memory cost of this strategy - sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1] - sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[ - mesh_dim_1] - sharding_size_weight = 1 - memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward, - sharding_size_backward_activation, - sharding_size_weight) - - # This strategy do not need to do all_reduce in forward phase - communication_cost_forward = 0 - # compute the backward communication cost to all reduce the weight due to data parallel - communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost( - memory_cost_backward_weight, 0) - # compute the total communication cost - communication_cost = communication_cost_backward_weight + communication_cost_forward - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): - name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' - - dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - bs = self.input_data.shape[0] - channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] * - self.device_mesh.shape[mesh_dim_1]) - channel_out = self.weight.shape[1] - compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) - - # compute the memory cost of this strategy - sharding_size_forward = 1 - sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[ - mesh_dim_1] - sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1] - memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost( - sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) - - # compute communication cost during forward phase - communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost( - memory_cost_forward_activation, 0) - # This strategy do NOT need do all_reduce during backward phase - communication_cost_backward = 0 - communication_cost = communication_cost_forward + communication_cost_backward - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - def register_strategy(self) -> StrategiesVector: - ''' - Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector. - - Example: - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - shape_consistency_manager = ShapeConsistencyManager() - - model = ConvModel(16, 32) - input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) - # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {}) - # return conv - graph = tracer.trace(root=model, meta_args=input_sample) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() - # [x, mul, conv, output] - nodes = [node for node in gm.graph.nodes] - - # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]] - strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input) - setattr(nodes[1], 'strategies_vector', strategies_vector_for_input) - - strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ]) - conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2], - device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager) - conv_handler.register_strategy_into_strategies_vector() - for strategy in conv_handler.strategies_vector: - print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}') - - Output: - S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]} - S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]} - S0R = S0R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]} - S1R = S1R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]} - S0R = S0S1 x S1R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 65538.002, 0, 0, 0, 65538.002, 196614.402]} - S1R = S1S0 x S0R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 65538.002, 65538.002, 196614.402, 0, 0]} - RS1 = RS0 x S0S1: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]} - RS0 = RS1 x S1S0: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]} - RR = RS0 x S0R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]} - RR = RS1 x S1R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]} - RS0 = RR x RS0: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]} - RS1 = RR x RS1: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]} - RR = RR x RR: compute_cost is 35426304, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]} - S01R = S01R x RR: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 65538.002, 262148.4, 0, 16385.001, 262148.4, 196614.402]} - RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]} - ''' - # SS = SR x RS - self.split_input_batch_weight_out_channel(0, 1) - self.split_input_batch_weight_out_channel(1, 0) - - # SR = SR x RR - self.split_input_batch(0) - self.split_input_batch(1) - - # SR = SS x SR - self.split_input_both_dim_weight_in_channel(0, 1) - self.split_input_both_dim_weight_in_channel(1, 0) - - # RS = RS x SS - self.split_input_in_channel_weight_both_channel(0, 1) - self.split_input_in_channel_weight_both_channel(1, 0) - - # RR = RS x SR - self.split_input_in_channel_weight_in_channel(0) - self.split_input_in_channel_weight_in_channel(1) - - # RS = RR x RS - self.split_weight_out_channel(0) - self.split_weight_out_channel(1) - - # RR= RR x RR - self.non_split() - - # S01R = S01R x RR - self.split_1d_parallel_on_input_batch(0, 1) - - # RR = RS01 x S01R - self.split_1d_parallel_on_in_channel(0, 1) - - return self.strategies_vector - - -CONV_STRATEGIES_LIST = [ - 'S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', - 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', - 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R' -] diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py deleted file mode 100644 index 1f2281cc4..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py +++ /dev/null @@ -1,756 +0,0 @@ -import operator -from enum import Enum -from functools import reduce -from typing import List - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector - -from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP -from .operator_handler import OperatorHandler -from .strategy_generator import IntermediateStrategy, StrategyGenerator - -__all__ = ['DotHandler'] - - -class DotProductStrategyGenerator(StrategyGenerator): - """ - DotProductStrategyGenerator is used to generate the sharding strategies for two 1D tensors in dot product computation. - This is created for torch.matmul where two tensors are 1D tensors. As torch.matmul does not include a bias argument, so we - do not consider bias here. - """ - - def validate(self, input, other): - assert input.dim() == 1 and other.dim() == 1 - - def no_split(self): - name = f'R = R dot R' - dim_partition_dict = {"input": {}, "other": {}, "output": {}} - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) - - def split_one_dim(self, mesh_dim): - name = f'S{mesh_dim} = S{mesh_dim} dot S{mesh_dim}' - dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}} - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim]) - - def generate(self) -> List[IntermediateStrategy]: - strategy_list = [] - - # do not split dimensions for dot product - # R = R dot R - strategy_list.append(self.no_split()) - - # split two tensors in the same dimensions - # S = S dot S - strategy_list.append(self.split_one_dim(0)) - strategy_list.append(self.split_one_dim(1)) - - return strategy_list - - -class MatVecStrategyGenerator(StrategyGenerator): - - def validate(self, input, other) -> bool: - assert input.dim() > 1 and other.dim() == 1 - - def no_split(self): - name = "R = R x R" - dim_partition_dict = {"input": {}, "other": {}, "output": {}} - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) - - def split_input_batch(self, mesh_dim): - name = f'S{mesh_dim}R = S{mesh_dim}R x R' - dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}} - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) - - def generate(self) -> List[IntermediateStrategy]: - strategy_list = [] - - # no split - strategy_list.append(self.no_split()) - - # split the batch dim for the first tensor only - strategy_list.append(self.split_input_batch(0)) - strategy_list.append(self.split_input_batch(1)) - - return strategy_list - - -class MatMulStrategyGenerator(StrategyGenerator): - """ - MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is - a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm. - - A matmul can be formulated as [n, p] x [p, q] = [n, q] - - Args: - is_linear (bool): whether this generator is used for nn.Linear and F.linear. - This will incur extra transformation of the dim partitioning as the weight is transposed. - """ - - def __init__(self, is_linear: bool, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_linear = is_linear - - # as the weight for the linear module is transposed, we can compute - # the correponding dimension indexfor convenience - if is_linear: - self.dim_q = 0 - self.dim_p = 1 - else: - self.dim_q = 1 - self.dim_p = 0 - - def validate(self, input, other, bias) -> bool: - # make sure the second tensor is a 2D tensor - assert input.dim() > 0 and other.dim() == 2 - - # make sure bias is of the same dimension - if self.is_linear: - assert bias is None or bias.shape[-1] == other.shape[0] - else: - assert bias is None or bias.shape[-1] == other.shape[1] - - def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): - # handle case SS = SR x RS - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' - - dim_partition_dict = { - "input": { - 0: [mesh_dim_0] - }, - "other": { - self.dim_q: [mesh_dim_1] - }, - "bias": { - -1: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - }, - } - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) - - def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): - # handle the case SR = SS x SR - name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' - dim_partition_dict = { - "input": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - }, - "other": { - self.dim_p: [mesh_dim_1] - }, - "bias": {}, - "output": { - 0: [mesh_dim_0] - }, - } - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1]) - - def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' - dim_partition_dict = { - "input": { - -1: [mesh_dim_0] - }, - "other": { - self.dim_p: [mesh_dim_0], - self.dim_q: [mesh_dim_1] - }, - "bias": { - -1: [mesh_dim_1] - }, - "output": { - -1: [mesh_dim_1] - }, - } - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) - - def recompute_split_both_contract(self, mesh_dim): - name = f'RR = RS{mesh_dim} x S{mesh_dim}R' - dim_partition_dict = { - "input": { - -1: [mesh_dim] - }, - "other": { - self.dim_p: [mesh_dim] - }, - "bias": {}, - "output": {}, - } - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim]) - - def split_rhs_space_only(self, mesh_dim): - name = f'RS{mesh_dim} = RR x RS{mesh_dim}' - dim_partition_dict = { - "input": {}, - "other": { - self.dim_q: [mesh_dim] - }, - "bias": { - -1: [mesh_dim] - }, - "output": { - -1: [mesh_dim] - }, - } - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim]) - - def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' - dim_partition_dict = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, - "other": {}, - "bias": {}, - "output": { - 0: [mesh_dim_0, mesh_dim_1] - }, - } - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) - - def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' - dim_partition_dict = { - "input": { - -1: [mesh_dim_0, mesh_dim_1] - }, - "other": { - self.dim_p: [mesh_dim_0, mesh_dim_1] - }, - "bias": {}, - "output": {}, - } - return IntermediateStrategy(name=name, - dim_partition_dict=dim_partition_dict, - all_reduce_axis=[mesh_dim_0, mesh_dim_1]) - - def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' - - dim_partition_dict = { - "input": {}, - "other": { - self.dim_q: [mesh_dim_0, mesh_dim_1] - }, - "bias": { - -1: [mesh_dim_0, mesh_dim_1] - }, - "output": { - -1: [mesh_dim_0, mesh_dim_1] - }, - } - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) - - -class BatchedMatMulStrategyGenerator(StrategyGenerator): - """ - Generate sharding strategies for the batched matrix multiplication. - - A batched matrix multiplication can be viewed as - [b, i, k] x [b, k, j] -> [b, i, j] - """ - - def __init__(self, is_torch_bmm: bool, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_torch_bmm = is_torch_bmm - - def validate(self, input, other, bias) -> bool: - if self.is_torch_bmm: - assert input.shape == other.shape - assert input.dim() > 2 - assert other.shape[-1] == bias.shape[0] - else: - # TODO: validate these inputs are broadcastable - pass - - def split_one_batch_dim(self): - if 1 in self.device_mesh.mesh_shape: - mesh_dim = self.device_mesh.mesh_shape.index(1) - name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}' - dim_partition_dict = { - "input": { - 0: [mesh_dim] - }, - "other": { - 0: [mesh_dim] - }, - "bias": {}, - "output": { - 0: [mesh_dim] - } - } - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) - else: - return None - - def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}' - dim_partition_dict = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0, mesh_dim_1] - }, - "bias": {}, - "output": { - 0: [mesh_dim_0, mesh_dim_1] - } - } - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) - - def split_one_batch_dim(self, mesh_dim): - name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}' - dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}} - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) - - def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}' - dim_partition_dict = { - "input": { - 0: [mesh_dim_0], - -2: [mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0] - }, - "bias": {}, - "output": { - 0: mesh_dim_0, - -2: [mesh_dim_1] - } - } - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) - - def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}' - dim_partition_dict = { - "input": { - 0: [mesh_dim_0] - }, - "other": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - }, - "bias": { - -1: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - } - } - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict) - - def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}' - dim_partition_dict = { - "input": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0], - -2: [mesh_dim_1] - }, - "bias": {}, - "output": { - 0: [mesh_dim_0], - -2: [mesh_dim_1] - } - } - return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1]) - - def generate(self) -> List[IntermediateStrategy]: - strategy_list = [] - - # split only the batch dimension - # Sb = Sb x Sb - # can be None as it is only for 1D device mesh - strategy = self.split_one_batch_dim() - if strategy: - strategy_list.append(strategy) - - # split batch dim of two inputs and the i dim of the first tensor - # SbSi = SbSi x Sb - strategy_list.append(self.split_batch_dim_lhs_space(0, 1)) - strategy_list.append(self.split_batch_dim_lhs_space(1, 0)) - - # split batch dim of two inputs and the j of the second tensor - # SbSj = Sb x SbSj - strategy_list.append(self.split_batch_dim_rhs_space(0, 1)) - strategy_list.append(self.split_batch_dim_rhs_space(1, 0)) - - # split batch dim of two inputs and the k dim of two inputs - # Sb = SbSk x SbSk, need to all-reduce by k dim - strategy_list.append(self.split_batch_dim_both_contract(0, 1)) - strategy_list.append(self.split_batch_dim_both_contract(1, 0)) - - # split two batch dim - strategy_list.append(self.split_two_batch_dim(0, 1)) - strategy_list.append(self.split_two_batch_dim(1, 0)) - - return strategy_list - - -class DotHandler(OperatorHandler): - """ - A OperatorHandler which deals with the sharding strategies for nn.Linear and F.linear. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.input_data = self.predecessor_node[0]._meta_data - self.weight = self.module_named_parameters['weight'] - self.output_data = self.node._meta_data - - def _generate_compute_cost(self, input_shape, weight_shape, total_sharding_size): - # TODO: consider bias addition - compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size - return compute_cost - - @ignore_sharding_exception - def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): - # handle case SS = SR x RS - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' - - dim_partition_dict_for_input = {0: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - # linear layer weight is transposed during init - dim_partition_dict_for_weight = {0: [mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute computation cost - total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) - - # compute the memory cost of this strategy - toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( - dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) - - # compute the communication cost - communication_cost_activation_backward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1) - communication_cost_weight_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0) - communication_cost = communication_cost_activation_backward + communication_cost_weight_backward - - # create and register strategy - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=toatl_memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): - # handle the case SR = SS x SR - name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' - - dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - # since weight of the linear layer is transposed - # the actual dim to be sharded is 1 - dim_partition_dict_for_weight = {1: [mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) - - # compute the memory cost of this strategy - toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( - dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) - - # compute the communication cost of this strategy - communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1) - communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0) - communication_cost = communication_cost_activation_forward + communication_cost_grad_backward - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=toatl_memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' - - dim_partition_dict_for_input = {1: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {1: [mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) - - # compute the memory cost of this strategy - toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( - dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) - - # compute the communication cost of this strategy - communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_0) - communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim_1) - communication_cost = communication_cost_activation_backward + communication_cost_activation_forward - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=toatl_memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def recompute_split_both_contract(self, mesh_dim): - name = f'RR = RS{mesh_dim} x S{mesh_dim}R' - - dim_partition_dict_for_input = {1: [mesh_dim]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {1: [mesh_dim]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - total_sharding_size = self.device_mesh.shape[mesh_dim] - compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) - - # compute the memory cost of this strategy - toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( - dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) - - # compute the communication cost of this strategy - communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim) - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=toatl_memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_rhs_space_only(self, mesh_dim): - name = f'RS{mesh_dim} = RR x RS{mesh_dim}' - - dim_partition_dict_for_input = {} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {0: [mesh_dim]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {1: [mesh_dim]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - total_sharding_size = self.device_mesh.shape[mesh_dim] - compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) - - # compute the memory cost of this strategy - toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( - dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) - - # compute the communication cost of this strategy - communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim) - communication_cost = communication_cost_activation_backward - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=toatl_memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' - - dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) - - # compute the memory cost of this strategy - toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( - dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) - - # compute the communication cost of this strategy - communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0) - communication_cost = communication_cost_weight_backward - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=toatl_memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' - - dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) - - # compute the memory cost of this strategy - toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( - dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) - - # compute the communication cost of this strategy - communication_cost_forward_activation = self.device_mesh.flatten_device_mesh.all_reduce_cost( - activation_memory_cost, 0) - communication_cost = communication_cost_forward_activation - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=toatl_memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' - - dim_partition_dict_for_input = {} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size) - - # compute the memory cost of this strategy - toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost( - dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input) - # compute the communication cost of this strategy - communication_cost_activation_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost( - input_grad_memory_cost, 0) - communication_cost = communication_cost_activation_backward - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=toatl_memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - def register_strategy(self) -> StrategiesVector: - ''' - Generate every possible strategies for a linear node, and record all strategies into the strategies_vector. - - Output: - - ''' - # SS = SR x RS - self.split_lhs_space_rhs_space(0, 1) - self.split_lhs_space_rhs_space(1, 0) - - # SR = SS x SR - self.split_lhs_space_both_contract(0, 1) - self.split_lhs_space_both_contract(1, 0) - - # RS = RS x SS - self.split_rhs_space_both_contract(0, 1) - self.split_rhs_space_both_contract(1, 0) - - # RR= RS x SR - self.recompute_split_both_contract(0) - self.recompute_split_both_contract(1) - - # RS = RR x RS - self.split_rhs_space_only(0) - self.split_rhs_space_only(1) - - # S01R = S01R x RR - self.split_lhs_1st_dim_1d(0, 1) - - # RR = RS01 x S01R - self.split_lhs_2nd_dim_1d(0, 1) - - # RS01 = RR x RS01 - self.split_rhs_2nd_dim_1d(0, 1) - - return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py deleted file mode 100644 index d01a487ad..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py +++ /dev/null @@ -1,179 +0,0 @@ -import operator -import warnings -from copy import deepcopy -from functools import reduce -from typing import Dict, List - -import torch -from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ - ignore_sharding_exception -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.tensor.sharding_spec import ShardingSpec - -from .operator_handler import OperatorHandler - -__all__ = ['EmbeddingHandler'] - - -class EmbeddingHandler(OperatorHandler): - """ - An OperatorHandler which deals with the sharding strategies of Embedding operators(such as nn.embedding). - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.input_data = self.predecessor_node[0]._meta_data - self.weight = self.module_named_parameters['weight'] - self.output_data = self.node._meta_data - - def _generate_compute_cost(self, total_sharding_size): - input_shape = self.input_data.shape - weight_shape = self.weight.shape - input_shape_product = reduce(operator.mul, input_shape, 1) - weight_shape_product = reduce(operator.mul, weight_shape, 1) - compute_cost = input_shape_product * weight_shape_product * 2 / total_sharding_size - return compute_cost - - def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight): - ''' - Compute the memory cost per device with this specific strategy. - - Argument: - sharding_size_forward(int): The forward activation will be divided - into sharding_size_forward number partions. - sharding_size_backward_activation(int): The backward activation will - be divided into sharding_size_backward_activation number partions. - sharding_size_weight(int): The backward weight will be divided - into sharding_size_weight number partions. - - Return: - memory_cost(Tuple[float]): Memory cost per device with this - specific strategy, the first element of this tuple is forward - memory cost, and the second element of this tuple is backward - memory cost. - memory_cost_forward(float): Memory cost of forward activation per - device with this specific strategy. - memory_cost_backward_activation(float): Memory cost of backward activation - per device with this specific strategy. - ''' - # compute the memory cost of this strategy - dtype = self.input_data.dtype - numel_output = self.output_data.numel() - numel_input = self.input_data.numel() - numel_weight = self.weight.numel() - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - - # forward memory_cost - memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward - memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight - memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight - - # backward memory_cost - memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation - memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight - memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight - - # memory_cost pair - memory_cost = (memory_cost_forward, memory_cost_backward) - - return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight - - @ignore_sharding_exception - def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1): - name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}' - - dim_partition_dict_for_input = {} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {2: [mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1] - compute_cost = self._generate_compute_cost(total_sharding_size) - - # compute the memory cost of this strategy - sharding_size_forward = self.device_mesh.shape[mesh_dim_1] - sharding_size_backward_activation = 1 - sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost( - sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) - - # compute the communication cost of this strategy during forward phase - communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) - # compute the communication cost of this strategy during backward phase - communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1) - communication_cost = communication_cost_forward + communication_cost_backward - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR' - - dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) - - # compute the computation cost of this strategy - total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1] - compute_cost = self._generate_compute_cost(total_sharding_size) - - # compute the memory cost of this strategy - sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - sharding_size_weight = 1 - memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost( - sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) - - # This strategy do not need to do all_reduce during forward phase - communication_cost_forward = 0 - # compute the communication cost of this strategy during backward phase - communication_cost_backward_activation = 0 - communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost( - memory_cost_backward_weight, 0) - communication_cost_backward = communication_cost_backward_activation + communication_cost_backward_weight - communication_cost = communication_cost_forward + communication_cost_backward - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.append(sharding_strategies) - - def register_strategy(self) -> StrategiesVector: - ''' - Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector. - ''' - # RRS = RR x SS - self.split_weight_both_dim(0, 1) - self.split_weight_both_dim(1, 0) - - # SSR = SS x RR - self.split_input_both_dim(0, 1) - self.split_input_both_dim(1, 0) - - return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py deleted file mode 100644 index 8062d0f4b..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py +++ /dev/null @@ -1,241 +0,0 @@ -import operator -from functools import reduce - -import torch - -from colossalai.auto_parallel.tensor_shard.deprecated._utils import ( - enumerate_all_possible_1d_sharding, - enumerate_all_possible_2d_sharding, - generate_sharding_size, - ignore_sharding_exception, -) -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector - -from .operator_handler import OperatorHandler - -__all__ = ['LayerNormHandler'] - - -class LayerNormHandler(OperatorHandler): - """ - A OperatorHandler which deals with the sharding strategies of normalization. - - Note: To keep the math consistency, LayerNorm do not allow shards on hidden dimension. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.input_data = self.predecessor_node[0]._meta_data - self.weight = self.module_named_parameters['weight'] - self.bias = self.module_named_parameters['bias'] - self.output_data = self.node._meta_data - - def _generate_compute_cost(self, total_sharding_size): - ''' - Compute the computation cost per device with this specific strategy. - - Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. - - Argument: - bs(int): Batch size of the input data. - channel_in(int): The channel dimension of input data. - - Return: - compute_cost(float): Computation cost per device with this specific strategy - ''' - # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. - # TODO: a constant coefficient need to be added. - - norm_kernel_size = self.weight.shape - # in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization. - input_batch_shape = self.input_data.shape[:-len(norm_kernel_size)] - input_batch_product = reduce(operator.mul, input_batch_shape, 1) - norm_kernel_product = reduce(operator.mul, norm_kernel_size, 1) - forward_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size - backward_activation_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size - # To compute gradient of on norm kernel element requires input_batch_product times computation, so - # the total cost is input_batch_product * norm_kernel_product - backward_weight_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size - backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost - compute_cost = forward_compute_cost + backward_compute_cost - return compute_cost - - def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight): - ''' - Compute the memory cost per device with this specific strategy. - - Argument: - sharding_size_forward(int): The forward activation will be divided - into sharding_size_forward number partions. - sharding_size_backward_activation(int): The backward activation will - be divided into sharding_size_backward_activation number partions. - sharding_size_weight(int): The backward weight will be divided - into sharding_size_weight number partions. - - Return: - memory_cost(Tuple[float]): Memory cost per device with this - specific strategy, the first element of this tuple is forward - memory cost, and the second element of this tuple is backward - memory cost. - memory_cost_forward(float): Memory cost of forward activation per - device with this specific strategy. - memory_cost_backward_activation(float): Memory cost of backward activation - per device with this specific strategy. - ''' - # compute the memory cost of this strategy - dtype = self.input_data.dtype - numel_output = self.output_data.numel() - # this operation will not change the shape of input - numel_input = numel_output - numel_weight = self.weight.numel() - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - - # forward memory_cost - memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward - memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight - memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight - - # backward memory_cost - memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation - memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight - memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight - - # memory_cost pair - memory_cost = (memory_cost_forward, memory_cost_backward) - - return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight - - def _generate_strategy_with_dim_partition(self, dim_partition): - dim_partition_dict_for_input = dim_partition - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = dim_partition - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_input.sharding_sequence} x {sharding_spec_for_weight.sharding_sequence}' - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) - - total_sharding_size = generate_sharding_size(dim_partition, self.device_mesh) - # compute the computation cost of this strategy - compute_cost = self._generate_compute_cost(total_sharding_size) - - # compute the memory cost of this strategy - sharding_size_forward = generate_sharding_size(dim_partition_dict_for_input, self.device_mesh) - sharding_size_backward_activation = generate_sharding_size(dim_partition_dict_for_output, self.device_mesh) - sharding_size_weight = generate_sharding_size(dim_partition_dict_for_weight, self.device_mesh) - memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward, - sharding_size_backward_activation, - sharding_size_weight) - - total_mesh_dim_list = [] - for mesh_dim_list in dim_partition.values(): - total_mesh_dim_list.extend(mesh_dim_list) - - # This strategy do not need to do all_reduce operation for activation - communication_cost_forward_activation = 0 - communication_cost_backward_activation = 0 - if len(total_mesh_dim_list) == 1: - communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, - total_mesh_dim_list[0]) - else: - assert len(total_mesh_dim_list) == 2, f'temporally we just support 2d device mesh.' - communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost( - memory_cost_backward_weight, 0) - communication_cost = communication_cost_forward_activation + communication_cost_backward_activation + communication_cost_backward_weight - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - - self.strategies_vector.append(sharding_strategies) - - @ignore_sharding_exception - def split_input_batch_single_mesh_dim(self, mesh_dim_0): - batch_dimension_length = self.input_data.dim() - self.weight.dim() - dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length) - for dim_partition in dim_partition_list: - self._generate_strategy_with_dim_partition(dim_partition) - - @ignore_sharding_exception - def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1): - batch_dimension_length = self.input_data.dim() - self.weight.dim() - dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length) - for dim_partition in dim_partition_list: - self._generate_strategy_with_dim_partition(dim_partition) - - @ignore_sharding_exception - def non_split(self): - name = f'RR = RR x R' - - dim_partition_dict_for_input = {} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) - - dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) - - dim_partition_dict_for_output = {} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) - - total_sharding_size = 1 - # compute the computation cost of this strategy - compute_cost = self._generate_compute_cost(total_sharding_size) - - # compute the memory cost of this strategy - sharding_size_forward = 1 - sharding_size_backward_activation = 1 - sharding_size_weight = 1 - memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, - sharding_size_weight) - - # This strategy do not need to do all_reduce operation - communication_cost = 0 - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_output, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - - self.strategies_vector.append(sharding_strategies) - - def register_strategy(self) -> StrategiesVector: - ''' - Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector. - - Example: - norm_handler = BatchNormHandler(node, strategies_vector, - self.shape_consistency_manager) - norm_handler.register_strategy() - for strategy in norm_handler.strategies_vector: - print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}') - - Output: - RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0 - RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0 - RR = RR x R, computation_cost: 262144, memory_cost: 1048576 - RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0 - ''' - - # SR = SR x R with single mesh dim on batch dimensions - self.split_input_batch_single_mesh_dim(0) - self.split_input_batch_single_mesh_dim(1) - - # SR = SR x R with both mesh dims on batch dimensions - self.split_input_batch_both_mesh_dim(0, 1) - - # RR = RR x R - self.non_split() - - return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py deleted file mode 100644 index b120cc16b..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py +++ /dev/null @@ -1,149 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Dict, List -from webbrowser import Opera - -import torch -import torch.nn as nn -from torch.fx.node import Node - -from colossalai.auto_parallel.tensor_shard.deprecated.constants import * -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.sharding_spec import ShardingSpec - -from .._utils import generate_resharding_costs, generate_sharding_spec -from ..sharding_strategy import StrategiesVector - -__all__ = ['OperatorHandler'] - - -class OperatorHandler(ABC): - ''' - The OperatorHandler is an abstract class used to generate every possible strategies for an operator node. - - Args: - node (Node): the input node in node argument list. - device_mesh (DeviceMesh): A logical view of a physical mesh. - strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector. - handle_backward (Optional[bool]): whether to consider the backward pass. The default value is True. False can be used for inference. - ''' - - def __init__(self, - node: Node, - device_mesh: DeviceMesh, - strategies_vector: StrategiesVector, - handle_backward: bool = True): - self.node = node - self.predecessor_node = list(node._input_nodes.keys()) - self.successor_node = list(node.users.keys()) - self.device_mesh = device_mesh - self.strategies_vector = strategies_vector - self.handle_backward = handle_backward - - # find the module and its parameters associated with this node - # this can be used to compute the compute/communication/sharding cost - if self.node.op == 'call_module': - module = node.graph.owning_module.get_submodule(node.target) - named_parameters = list(module.named_parameters(recurse=False)) - # convert named parameters from list to dict - named_parameters = {k: v for k, v in named_parameters} - elif self.node.op == 'call_function' and self.node.target not in NON_PARAM_FUNC_OP: - module = None - parameters = list(self.node.args)[1] - if isinstance(parameters, Node): - named_parameters = {'weight': parameters._meta_data} - else: - named_parameters = {} - else: - module = None - named_parameters = None - self.module = module - self.module_named_parameters = named_parameters - - @abstractmethod - def register_strategy(self) -> StrategiesVector: - """ - Register - """ - pass - - def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight, - sharding_spec_for_input): - ''' - Compute the memory cost per device with this specific strategy. - - Argument: - dim_partition_dict_for_output(List[int]): The key is the dimension of output to be sharded, - and the value of the key decribe which logical axis will be sharded in that dimension. - dim_partition_dict_for_weight(List[int]): The key is the dimension of weight to be sharded, - and the value of the key decribe which logical axis will be sharded in that dimension. - Return: - total_memory_cost(float): total memory cost per device with this specific strategy - activation_cost(float): the memory cost of activation per device with this specific strategy - weight_memory_cost(float): the memory cost of weight per device with this specific strategy - ''' - # compute the size of one element with specific dtype - dtype = self.input_data.dtype - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - - # compute the memory cost of activation - activation_numel = self.output_data.numel() - output_mesh_dims = [] - for sharding_dim, mesh_dims in dim_partition_dict_for_output.items(): - output_mesh_dims.extend(mesh_dims) - activation_sharding_size = 1 - for mesh_dim in output_mesh_dims: - activation_sharding_size *= self.device_mesh.shape[mesh_dim] - activation_memory_cost = activation_numel / activation_sharding_size * size_per_elem_bytes - - # compute the memory cost of weight - weight_numel = self.weight.numel() - weight_sharding_size = 1 - weight_mesh_dims = [] - for sharding_dim, mesh_dims in dim_partition_dict_for_weight.items(): - weight_mesh_dims.extend(mesh_dims) - for mesh_dim in weight_mesh_dims: - weight_sharding_size *= self.device_mesh.shape[mesh_dim] - weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes - - # compute the memory cost of input grad - input_grad_numel = self.input_data.numel() - input_grad_sharding_size = 1 - input_grad_mesh_dims = [] - for sharding_dim, mesh_dims in sharding_spec_for_input.items(): - input_grad_mesh_dims.extend(mesh_dims) - for mesh_dim in input_grad_mesh_dims: - input_grad_sharding_size *= self.device_mesh.shape[mesh_dim] - input_grad_memory_cost = input_grad_numel / input_grad_sharding_size * size_per_elem_bytes - - memory_cost_forward = activation_memory_cost + weight_memory_cost - memory_cost_backward = input_grad_memory_cost + weight_memory_cost - - return (memory_cost_forward, - memory_cost_backward), activation_memory_cost, weight_memory_cost, input_grad_memory_cost - - def _generate_resharding_costs(self, sharding_specs): - # The resharding_cost of weight is counted due to sharing weight cases. - if hasattr(self.node._meta_data, 'dtype'): - dtype = self.node._meta_data.dtype - else: - assert isinstance(self.node._meta_data, - tuple), f'Only torch.Tensor, torch.fx.Node and tuple of torch.Tensor is expected' - dtype = self.node._meta_data[0].dtype - - nodes = self.predecessor_node - return generate_resharding_costs(nodes=nodes, - sharding_specs=sharding_specs, - count_backward=self.handle_backward, - dtype=dtype) - - def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: - return generate_sharding_spec(input_=input_, - device_mesh=self.device_mesh, - dim_partition_dict=dim_partition_dict) - - @abstractmethod - def _generate_compute_cost(self, *args, **kwargs): - """ - Compute the flops involved in the node. - """ - pass diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py deleted file mode 100644 index d4ccc8a9c..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py +++ /dev/null @@ -1,89 +0,0 @@ -import colorsys -import math -import warnings -from copy import deepcopy - -import torch - -from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.tensor.sharding_spec import ShardingSpec - -from ..constants import INFINITY_COST -from .operator_handler import OperatorHandler - - -class ReshapeHandler(OperatorHandler): - """ - An OperatorHandler which deals with the sharding strategies of Reshape Operator, such as torch.reshape, torch.flatten, etc. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.input_data = self.predecessor_node[0]._meta_data - self.output_data = self.node._meta_data - - def _generate_compute_cost(self, *args, **kwargs): - return super()._generate_compute_cost(*args, **kwargs) - - @ignore_sharding_exception - def register_strategy(self): - # TODO: add strategies with more output sharding specs other than only fully replicated. - input_node = self.strategies_vector.predecessor_nodes[0] - # For reshape function, to keep the computing correctness we keep the sharding - # spec of input is fully replicated. In addition, we will keep the output in - # replica status and let the successor node choose the way to resharding the - # output node. Therefore, the different strategies of input node with same - # output sharding spec will generate same strategy for reshape function. - sharding_spec_checklist = [] - for strategy in input_node.strategies_vector: - # It looks a little bit confusing, the input of the processing node - # is the output of the input_node. - input_sharding_spec = strategy.output_sharding_spec - assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' - if input_sharding_spec in sharding_spec_checklist: - continue - sharding_spec_checklist.append(input_sharding_spec) - dim_partition_dict_for_output = {} - if isinstance(self.output_data, tuple): - dim_partition_dict_for_output = [{} for _ in range(len(self.output_data))] - try: - if isinstance(self.output_data, tuple): - output_sharding_spec = [] - for output, dim_partition_dict in zip(self.output_data, dim_partition_dict_for_output): - output_sharding_spec.append(self._generate_sharding_spec(output, dim_partition_dict)) - else: - output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) - except AssertionError as e: - warnings.warn(f'{e}') - continue - name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED' - # TODO: use meta_info_prop to profile memory cost and compute cost - compute_cost = 0 - # consider node._meta_data is in type of tuple - memory_cost = 0 - - # compute the communication cost, in reshape op, the communication happens during casting the input sharding spec to fully replicating. - dim_partition_dict_for_replicate_input = {} - replicate_input_sharding_spec = self._generate_sharding_spec(self.input_data, - dim_partition_dict_for_replicate_input) - # shape consistency manager is a singleton class - shape_consistency_manager = ShapeConsistencyManager() - _, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec, - replicate_input_sharding_spec) - communication_cost = communication_cost["total"] - - # generate resharding cost - resharding_costs = self._generate_resharding_costs([input_sharding_spec]) - - # to prevent the resharding happening, set their resharding cost to inf. - resharding_costs[input_node] = [0 if cost == 0 else INFINITY_COST for cost in resharding_costs[input_node]] - sharding_strategy = ShardingStrategy(name, - output_sharding_spec, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=[input_sharding_spec]) - self.strategies_vector.append(sharding_strategy) diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py deleted file mode 100644 index 4e39fcd8e..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py +++ /dev/null @@ -1,45 +0,0 @@ -from dataclasses import dataclass -from abc import ABC, abstractmethod -from typing import List, Dict -from colossalai.device.device_mesh import DeviceMesh - -__all__ = ['IntermediateStrategy', 'StrategyGenerator'] - - -@dataclass -class IntermediateStrategy: - """ - IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is - to store the essential information regarding the tensor sharding and leave other meta information to OperatorHandler. - - Args: - name (str): name of the sharding strategy. - dim_partition_dict (Dict[Dict]): stores the tensor to dim partition dict mapping. - all_reduce_dims (List[int]): stores the dimensions which require an all-reduce operation. - """ - name: str - dim_partition_dict: Dict[str, Dict[int, List[int]]] - all_reduce_axis: List[int] = None - - -class StrategyGenerator(ABC): - """ - StrategyGenerator is used to generate the same group of sharding strategies. - """ - - def __init__(self, device_mesh: DeviceMesh): - self.device_mesh = device_mesh - - @abstractmethod - def generate(self) -> List[IntermediateStrategy]: - """ - """ - pass - - @abstractmethod - def validate(self, *args, **kwargs) -> bool: - """ - Validate if the operands are of desired shape. - If True, means this generator can be used for the current operation. - """ - pass diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py deleted file mode 100644 index c929d2fad..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py +++ /dev/null @@ -1,88 +0,0 @@ -import math -import operator -import warnings -from copy import deepcopy -from functools import reduce -from typing import Dict, List - -import torch -from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ - ignore_sharding_exception -from colossalai.auto_parallel.tensor_shard.deprecated.constants import \ - INFINITY_COST -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.tensor.sharding_spec import ShardingSpec - -from .operator_handler import OperatorHandler - -__all__ = ['UnaryElementwiseHandler'] - - -class UnaryElementwiseHandler(OperatorHandler): - """ - An OperatorHandler which deals with the sharding strategies of UnaryElementwiseOp. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.node.op == 'call_module': - target = self.node.target - submod = self.node.graph.owning_module.get_submodule(target) - submod_type = type(submod) - if submod_type == torch.nn.Dropout: - print(f'predecessor nodes of dropout node are {self.predecessor_node}') - input_nodes_len = 0 - for check_node in self.predecessor_node: - if isinstance(check_node._meta_data, torch.Tensor): - input_nodes_len += 1 - assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op, node name is {self.node}, node args is {self.node.args}.' - self.input_data = self.predecessor_node[0]._meta_data - self.input_node = self.predecessor_node[0] - self.output_data = self.node._meta_data - - def _generate_compute_cost(self, *args, **kwargs): - return super()._generate_compute_cost(*args, **kwargs) - - @ignore_sharding_exception - def register_strategy(self): - # TODO: integrate element-wise func and module together - # create sharding strategy for element-wise function - - # For element-wise function, we keep the sharding spec of output node same as - # the input. Therefore, the different strategies of input node with same - # output sharding spec will generate same strategy for element-wise function. - - for index, strategy in enumerate(self.input_node.strategies_vector): - # It looks a little bit confusing, the input of the processing node - # is the output of the input_node. - input_sharding_spec = strategy.output_sharding_spec - assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' - - dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict) - try: - output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict) - except AssertionError as e: - warnings.warn(f'{e}') - continue - # add index into name to pass the duplicated check - # we keep same strategies with different name for node merging, and it will not increase the searching space, - # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. - name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}' - # TODO: use meta_info_prop to profile memory cost and compute cost - compute_cost = self.output_data.numel() - memory_cost = 0 - - resharding_costs = self._generate_resharding_costs([input_sharding_spec]) - - # to prevent the resharding happening, set their resharding cost to inf. - resharding_costs[self.input_node] = [ - 0 if cost == 0 else INFINITY_COST for cost in resharding_costs[self.input_node] - ] - sharding_strategy = ShardingStrategy(name, - output_sharding_spec, - compute_cost=compute_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=[input_sharding_spec]) - self.strategies_vector.append(sharding_strategy) diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py deleted file mode 100644 index 6991e913d..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py +++ /dev/null @@ -1,186 +0,0 @@ -import operator -import warnings -from copy import deepcopy -from functools import reduce -from typing import Dict, List - -import torch - -from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding, - enumerate_all_possible_2d_sharding, - ignore_sharding_exception) -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.tensor.sharding_spec import ShardingSpec - -from .operator_handler import OperatorHandler - -__all__ = ['WhereHandler'] - - -class WhereHandler(OperatorHandler): - """ - An OperatorHandler which deals with the sharding strategies of torch.where. - """ - - def __init__(self, *args, **kwargs): - # TODO: x or y could be scalar - super().__init__(*args, **kwargs) - assert len(self.predecessor_node) == 3 - self.condition_data = self.predecessor_node[0]._meta_data - self.x_data = self.predecessor_node[1]._meta_data - self.y_data = self.predecessor_node[2]._meta_data - self.condition = self.predecessor_node[0] - self.x = self.predecessor_node[1] - self.y = self.predecessor_node[2] - self.output_data = self.node._meta_data - - def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: - shape = list(input_.shape) - - # padding the shape to the same length as output_data - while len(shape) < self.output_data.dim(): - shape.insert(0, 1) - shape = torch.Size(shape) - - # if the sharding happens on a size one dimension, we should record it as R. - processed_dim_partition_dict = deepcopy(dim_partition_dict) - for dim_index, _ in dim_partition_dict.items(): - if shape[dim_index] == 1: - processed_dim_partition_dict.pop(dim_index) - for dim_index, sharding_index_list in processed_dim_partition_dict.items(): - sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] - sharding_size = reduce(operator.mul, sharding_list, 1) - assert shape[ - dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' - sharding_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=shape, - dim_partition_dict=processed_dim_partition_dict) - - return sharding_spec - - def _generate_compute_cost(self, total_sharding_size): - lhs_matrix_shape = self.lhs_data.shape[-2:] - rhs_matrix_shape = self.rhs_data.shape[-2:] - batch_dimensions_shape = self.output_data.shape[:-2] - batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1) - compute_cost = reduce( - operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size - return compute_cost - - def _generate_resharding_costs(self, sharding_specs): - # The resharding_cost of weight is counted due to sharing weight cases. - dtype = self.node._meta_data.dtype - nodes = self.predecessor_node - resharding_costs = {} - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - - # shape consistency manager is a singleton class - shape_consistency_manager = ShapeConsistencyManager() - - for input_node, input_spec in zip(nodes, sharding_specs): - resharding_costs[input_node] = [] - for strategy in input_node.strategies_vector: - input_sharding_spec = strategy.output_sharding_spec - assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' - # if the input shape is smaller than the target input, we will fill the input to the same length as target. - # Then, use the padded input sharding spec to compute the resharding cost. - if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape): - new_entire_shape = list(input_sharding_spec.entire_shape) - while len(new_entire_shape) < len(input_spec.entire_shape): - new_entire_shape.insert(0, 1) - new_entire_shape = torch.Size(new_entire_shape) - new_device_mesh = input_sharding_spec.device_mesh - new_dim_partition_dict = input_sharding_spec.dim_partition_dict - input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh, - entire_shape=new_entire_shape, - dim_partition_dict=new_dim_partition_dict) - - # compute the resharding cost - _, _, total_resharding_cost = shape_consistency_manager.shape_consistency( - input_sharding_spec, input_spec) - total_resharding_cost = total_resharding_cost['total'] - # we need multiply the size of elem dtype to get correct communication cost - resharding_cost = total_resharding_cost * size_per_elem_bytes - resharding_costs[input_node].append(resharding_cost) - - return resharding_costs - - def _convert_partition_dict_to_sharding_spec(self, dim_partition_list): - - sharding_spec_list = [] - check_duplicated_list = [] - for output_dim_partition_dict in dim_partition_list: - try: - output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict) - except AssertionError as e: - warnings.warn(f'{e}') - break - sharding_seq = output_sharding_spec.sharding_sequence - if sharding_seq not in check_duplicated_list: - check_duplicated_list.append(sharding_seq) - sharding_spec_list.append(output_sharding_spec) - - return sharding_spec_list - - def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): - # use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity. - - output_dim_partition_list = [] - dim_size = self.output_data.dim() - # enumerate all the 2D sharding cases - sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) - output_dim_partition_list.extend(sharding_list_2d) - - # enumerate all the 1D sharding cases - sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size) - output_dim_partition_list.extend(sharding_list_1d_on_dim_0) - sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size) - output_dim_partition_list.extend(sharding_list_1d_on_dim_1) - - # add empty dict for fully replicated case - output_dim_partition_list.append({}) - output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list) - - return output_sharding_spec_list - - @ignore_sharding_exception - def _register_strategy(self, output_sharding_spec): - dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict - sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input) - sharding_spec_for_x = self._generate_sharding_spec(self.x_data, dim_partition_dict_for_input) - sharding_spec_for_y = self._generate_sharding_spec(self.y_data, dim_partition_dict_for_input) - - name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_condition.sharding_sequence} x {sharding_spec_for_x.sharding_sequence} x {sharding_spec_for_y.sharding_sequence}' - dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict - - # generate resharding cost for this strategy - resharding_costs = self._generate_resharding_costs( - [sharding_spec_for_condition, sharding_spec_for_x, sharding_spec_for_y]) - - # compute the computation cost of this strategy - sharding_dims = [] - for mesh_dims in dim_partition_dict_for_output.values(): - for mesh_dim in mesh_dims: - sharding_dims.append(self.device_mesh.shape[mesh_dim]) - sharding_size = reduce(operator.mul, sharding_dims, 1) - memory_cost = self.output_data.numel() / sharding_size - compute_cost = memory_cost - communication_cost = 0 - - sharding_strategies = ShardingStrategy(name, - output_sharding_spec=output_sharding_spec, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=(sharding_spec_for_condition, sharding_spec_for_x, - sharding_spec_for_y)) - - self.strategies_vector.append(sharding_strategies) - - def register_strategy(self) -> StrategiesVector: - MESH_DIM_LIST = [0, 1] - output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1]) - for output_sharding_spec in output_sharding_specs: - self._register_strategy(output_sharding_spec) diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/options.py b/colossalai/auto_parallel/tensor_shard/deprecated/options.py deleted file mode 100644 index 2d34f5c64..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/options.py +++ /dev/null @@ -1,11 +0,0 @@ -from dataclasses import dataclass - -__all__ = ['SolverOptions'] - - -@dataclass -class SolverOptions: - """ - SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search. - """ - fast: bool = False diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/deprecated/sharding_strategy.py deleted file mode 100644 index d468c858e..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/sharding_strategy.py +++ /dev/null @@ -1,91 +0,0 @@ -from copy import deepcopy -from dataclasses import dataclass -from abc import ABC, abstractmethod -from enum import Enum -import operator -import torch -from functools import reduce - -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec -from typing import Dict, List, Union, Tuple, Any -from torch.fx.node import Node -from .constants import * - -__all__ = ['ShardingStrategy', 'StrategiesVector'] - - -@dataclass -class ShardingStrategy: - ''' - ShardingStrategy is a structure containing sharding strategies of inputs and output of this node - and costs information using in solver. - - Argument: - name(str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'. - output_sharding_spec(ShardingSpec): ShardingSpec of the output node. - compute_cost(float): Computation cost to complete this strategy.(default to 0) - communication_cost(float): Communication cost to complete this strategy.(default to 0) - memory_cost(float): Memory cost of the output node using this strategy.(default to 0) - resharding_costs(Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list - with j-th strategy in its strategies_vector transforms to sharding spec wanted in this - strategy.(default to None) - input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes. - ''' - - name: str - # TODO: output of fx node,such as torch.var_mean, could be a tuple, so we cannot simply suppose it is a tensor. - output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]] - compute_cost: float = 0. - communication_cost: float = 0. - memory_cost: float = 0. - resharding_costs: Dict[Node, List[float]] = None - # sometimes the input node could be a tuple of nodes, but most of op won't accept tuple of node as input. - # Therefore, we could process them at the specific op(operator.getitem) - input_shardings: List[ShardingSpec] = None - - -class StrategiesVector(list): - ''' - Each node in fx graph will have a corresponding StrategiesVector, to store all the possible - strategies of the node. - - Argument: - node (Node): node for which the list of sharding strategies are generated. - ''' - - def __init__(self, node: Node): - super().__init__() - self.node = node - # fetch its input and output nodes - # TODO: placeholder input nodes - self.predecessor_nodes = list(node._input_nodes.keys()) - if self.node.op == 'output': - self.predecessor_nodes = list(node._input_nodes.keys())[:1] - self.successor_nodes = list(node.users.keys()) - - def check_merge(self): - merge_label = False - if self.node.op == 'call_module': - target = self.node.target - root_module = self.node.graph.owning_module - submod = root_module.get_submodule(target) - submod_type = type(submod) - # merge elementwise module node into source nodes - # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec. - if submod_type in ELEMENTWISE_MODULE_OP: - merge_label = True - - if self.node.op == 'call_function': - # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec. - if self.node.target in ELEMENTWISE_FUNC_OP: - merge_label = True - # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case. - if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1: - merge_label = True - # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated. - if self.node.target in RESHAPE_FUNC_OP: - merge_label = True - - return merge_label diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/solver.py b/colossalai/auto_parallel/tensor_shard/deprecated/solver.py deleted file mode 100644 index 4c1d2f3be..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/solver.py +++ /dev/null @@ -1,469 +0,0 @@ -import multiprocessing -import time -import warnings -from typing import Dict - -import numpy as np -from torch.fx.graph import Graph -from torch.fx.node import Node - -from .constants import INFINITY_COST -from .cost_graph import CostGraph -from .graph_analysis import GraphAnalyser -from .strategies_constructor import StrategiesConstructor - -try: - import pulp - from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum -except: - warnings.warn(f'please install the pulp') - -__all___ = ['Solver'] - - -class Solver: - - def __init__(self, - graph: Graph, - strategies_constructor: StrategiesConstructor, - cost_graph: CostGraph, - graph_analyser: GraphAnalyser, - memory_budget: float = -1.0, - solution_numbers: int = 1, - memory_increasing_coefficient: float = 1.3): - ''' - Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. - - Argument: - graph: The computing graph to be optimized. - strategies_constructor: It will provide all the possible strategies for each node in the computing graph. - cost_graph: A graph data structure to simplify the edge cost graph. - graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints. - memory_budget: Memory constraint for the solution. - solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget. - memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget. - ''' - self.graph = graph - self.strategies_constructor = strategies_constructor - self.cost_graph = cost_graph - self.graph_analyser = graph_analyser - self.leaf_strategies = self.strategies_constructor.leaf_strategies - self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] - self.strategy_map = self.strategies_constructor.strategy_map - self.memory_budget = memory_budget - self.solution_numbers = solution_numbers - if self.solution_numbers > 1: - self.memory_increasing_coefficient = memory_increasing_coefficient - else: - self.memory_increasing_coefficient = 1 - self.liveness_list = self.graph_analyser.liveness_analysis() - self.node_index_dict = self._generate_node_index_dict() - # The last solution vector of auto sharding. - self.last_s_val = None - # The last objective value of the best ILP solution. - self.last_objective = None - - def _recover_merged_node_strategy(self): - ''' - During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node. - Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged - node. - ''' - for node_index, node in enumerate(self.nodes): - if node.strategies_vector.check_merge(): - # the merged node has only one input, and its strategies follow the input sharding strategy - input_strategies_vector = node.args[0].strategies_vector - input_best_strategy_index = self.last_s_val[node_index - 1] - input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec - for strategy_index, strategy in enumerate(node.strategies_vector): - if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence: - self.last_s_val[node_index] = strategy_index - break - - def _generate_node_index_dict(self) -> Dict[Node, int]: - node_index_dict = {} - for index, strategies_vector in enumerate(self.leaf_strategies): - node_index_dict[strategies_vector.node] = index - return node_index_dict - - def _prepare_data_for_solver(self): - ''' - Extract information from components for solver. - ''' - node_nums = len(self.leaf_strategies) - memory_budget = self.memory_budget - - # prepare strategies_len - strategies_len = [] - for node in self.nodes: - strategies_len.append(self.cost_graph.node_lens[node]) - strategies_len = np.array(strategies_len) - - # prepare following_nodes - following_nodes = self.cost_graph.following_dict - index_following_nodes = {} - for src, target in following_nodes.items(): - src_index = self.node_index_dict[src] - target_index = self.node_index_dict[target] - index_following_nodes[src_index] = target_index - following_nodes = index_following_nodes - for index in range(node_nums): - if index not in following_nodes: - following_nodes[index] = -1 - - # prepare edge_pairs and resharding costs - edge_pairs = [] - resharding_costs = [] - for pairs, edge_cost in self.cost_graph.edge_costs.items(): - src_node = pairs[0] - dst_node = pairs[1] - src_node_index = self.node_index_dict[src_node] - dst_node_index = self.node_index_dict[dst_node] - edge_pairs.append(src_node_index) - edge_pairs.append(dst_node_index) - - for i in range(strategies_len[src_node_index]): - for j in range(strategies_len[dst_node_index]): - resharding_costs.append(edge_cost[(i, j)]) - edge_pairs = np.array(edge_pairs) - resharding_costs = np.array(resharding_costs) - - # prepare liveness_set - liveness_set = self.liveness_list - - # omit alias_set now - alias_set = None - alias_convert_costs = None - - # prepare compute_costs, communication_costs and memory_costs - compute_costs = [] - communication_costs = [] - memory_costs = [] - extra_node_costs = self.cost_graph.extra_node_costs - for strategies_vector in self.leaf_strategies: - node = strategies_vector.node - for index, strategy in enumerate(strategies_vector): - compute_costs.append(strategy.compute_cost) - # node in extra_node_costs means it has some extra communication - # cost from node merging, so we need to add those extra communication - # cost into - if node in extra_node_costs: - origin_communication_cost = strategy.communication_cost - extra_node_cost = extra_node_costs[node][index] - communication_cost = origin_communication_cost + extra_node_cost - communication_costs.append(communication_cost) - else: - communication_costs.append(strategy.communication_cost) - # temporarily we just consider the forward memory cost - memory_cost = strategy.memory_cost - if isinstance(memory_cost, tuple): - memory_costs.append(memory_cost[0]) - else: - memory_costs.append(memory_cost) - compute_costs = np.array(compute_costs) - communication_costs = np.array(communication_costs) - memory_costs = np.array(memory_costs) - - # omit initial value for nodes - s_init_np = None - - return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np - - def _call_solver_serialized_args(self, - node_nums, - memory_budget, - strategies_len, - following_nodes, - edge_pairs, - alias_set, - liveness_set, - compute_costs, - communication_costs, - memory_costs, - resharding_costs, - alias_convert_costs, - s_init_np=None): - """ - Call the solver with serialized arguments. - """ - - tic = time.time() - - for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]: - assert isinstance(x, np.ndarray) - assert len(strategies_len) == node_nums, "strategies_len" - - def get_non_zero_index(binary_vector): - """ - Get the index of non-zero item in a vector. - """ - ct = 0 - ret = None - for i, elem in enumerate(binary_vector): - if pulp.value(elem): - ret = i - ct += 1 - - assert ct == 1 - return ret - - # 0. Unpack flatten numpy arrays - s_follow = following_nodes - - E = edge_pairs.reshape((-1, 2)) # noqa - r = [] - pt = 0 - edge_set = set() - for (i, j) in E: - prod_length = strategies_len[i] * strategies_len[j] - - if (i, j) in edge_set: - raise ValueError(f"Duplicated edges: {(i, j)}") - - edge_set.add((i, j)) - r.append(resharding_costs[pt:pt + prod_length]) - pt += prod_length - assert pt == len(resharding_costs) - - ###################### - # omit alias set now # - ###################### - - # A = alias_set.reshape((-1, 2)) # noqa - # for (i, j) in A: - # prod_length = strategies_len[i] * strategies_len[j] - # v.append(alias_convert_costs[pt:pt + prod_length]) - # pt += prod_length - # assert pt == len(alias_convert_costs) - - # L = [] # noqa - # pt = node_nums - # for i in range(node_nums): - # length = liveness_set[i] - # L.append(liveness_set[pt:pt + length]) - # pt += length - # assert pt == len(liveness_set) - v = [] - pt = 0 - - c = [] - d = [] - m = [] - pt = 0 - for i in range(node_nums): - length = strategies_len[i] - c.append(compute_costs[pt:pt + length]) - d.append(communication_costs[pt:pt + length]) - m.append(memory_costs[pt:pt + length]) - pt += length - assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" - assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}" - assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}" - - # 1. Create variables - - ############################# - # create variables for node # - ############################# - s = [] - num_nodes = 0 - reverse_follow_backpatch = [] - for i in range(node_nums): - if s_follow[i] < 0: - if strategies_len[i] == 1: - s.append([1]) - else: - num_nodes += 1 - s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary")) - else: - if s_follow[i] < len(s): - s.append(s[s_follow[i]]) - else: - s.append(None) - reverse_follow_backpatch.append(i) - - for i in reverse_follow_backpatch: - s[i] = s[s_follow[i]] - - ############################# - # create variables for edge # - ############################# - e = [] - num_edges = 0 - for (idx, (i, j)) in enumerate(E): - if len(s[i]) == 1: - e.append(s[j]) - elif len(s[j]) == 1: - e.append(s[i]) - else: - num_edges += 1 - e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) - assert len(e[idx]) == len(r[idx]) - for element in s: - assert len(element) > 0 - # 2. Set initial value - ###################################### - # set a initial value for warm start # - ###################################### - if s_init_np is not None: - s_init = s_init_np.reshape((-1, 3)) - for (idx, value, fix) in s_init: - for i in range(len(s[idx])): - s[idx][i].setInitialValue(i == value) - if fix: - s[idx][i].fixValue() - - # 3. Objective - prob = LpProblem("myProblem", LpMinimize) - ################################################################### - # computing the node cost(computing cost and communication cost) # - ################################################################### - obj = 0 - for i in range(node_nums): - assert len(s[i]) == len(c[i]) - assert len(s[i]) == len(d[i]) - - obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i]) - - ############################################# - # computing the edge cost(resharding cost) # - ############################################# - for i in range(len(E)): - assert len(e[i]) == len(r[i]) - obj += lpDot(e[i], r[i]) - - prob += obj - - # 4. Constraints - # (a). specified by `cat="Binary"` - - # (b) - ################################################# - # make sure each node only choose one strategy # - ################################################# - for i in range(node_nums): - if s_follow[i] < 0: - prob += lpSum(s[i]) == 1 - - # (c) - ################################################# - # compute memory consumption with liveness set # - ################################################# - if memory_budget > 0: - for liveness_stage in liveness_set: - mem = 0 - for live_variable in liveness_stage.unique_live_vars: - node_index = self.node_index_dict[live_variable.node] - mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) - prob += mem <= memory_budget - - # (d). specified by `cat="Binary"` - - for (idx, (i, j)) in enumerate(E): - if strategies_len[i] == 1 or strategies_len[j] == 1: - continue - - # (e) - prob += lpSum(e[idx]) == 1 - - # (f) - for row in range(len(s[i])): - C = len(s[j]) # noqa - prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row] - - # (g) - for col in range(len(s[j])): - R = len(s[i]) # noqa - C = len(s[j]) # noqa - prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col] - - # (h) - ###################### - # omit alias set now # - ###################### - - # alias_set = set() - # for (idx, (i, j)) in enumerate(A): - # R = len(s[i]) # noqa - # C = len(s[j]) # noqa - # if (i, j) in alias_set: - # raise ValueError(f"Duplicated edges: {(i, j)}") - - # alias_set.add((i, j)) - # alias_set.add((j, i)) - - # for row in range(len(s[i])): - # for col in range(len(s[j])): - # if v[idx][row * C + col] > 0.5: - # prob += s[i][row] + s[j][col] <= 1 - - verbose = True - - msg = verbose - time_limit = 600 - assert "COIN_CMD" in pulp.listSolvers( - onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'") - - solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count()) - # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit) - prob.solve(solver) - - status = prob.status - objective = pulp.value(prob.objective) - objective = float(objective) if objective is not None else -1.0 - if verbose: - print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" - f"Time: {time.time() - tic}") - print(f"#nodes: {num_nodes}, #edges: {num_edges}") - - if prob.status in [pulp.LpStatusInfeasible]: - raise RuntimeError("Cannot run the function under the given memory budget. " - "Please increase the memory budget.") - - # Get and check results - s_val = np.full((node_nums,), -1, dtype=np.int32) - for i in range(node_nums): - s_val[i] = get_non_zero_index(s[i]) - - e_val = np.full((len(E),), -1, dtype=np.int32) - for (idx, (i, j)) in enumerate(E): - e_val[idx] = get_non_zero_index(e[idx]) - i_spec_index = e_val[idx] // len(s[j]) - j_spec_index = e_val[idx] % len(s[j]) - assert i_spec_index == s_val[i], f"e_val[{i}][{j}]" - assert j_spec_index == s_val[j], f"e_val[{i}][{j}]" - if verbose and r[idx][e_val[idx]] > 0: - print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}") - - self.last_s_val = list(s_val) - self._recover_merged_node_strategy() - self.last_objective = objective - - if objective > INFINITY_COST: - warnings.warn("Detect unexpected behaviors in the auto-sharding pass.") - - return self.last_s_val, e_val, self.last_objective, status - - def call_solver_serialized_args(self): - """ - Call the solver with serialized arguments and handle python errors. Additionally, - we could give a serious of solutions with different memory budget. - """ - if self.solution_numbers == 1: - args = self._prepare_data_for_solver() - ret = self._call_solver_serialized_args(*args) - - return ret - - origin_memory_budget = self.memory_budget - memory_budget_list = [ - origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers) - ] - ret_list = [] - for memory_budget in memory_budget_list: - self.memory_budget = memory_budget - args = self._prepare_data_for_solver() - ret = self._call_solver_serialized_args(*args) - ret_list.append(ret) - - return ret_list diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py deleted file mode 100644 index 7bebde9d6..000000000 --- a/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py +++ /dev/null @@ -1,426 +0,0 @@ -import builtins -import math -import operator -from copy import deepcopy -from typing import Dict, List - -import torch -from torch.fx import Graph, Node - -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.tensor.sharding_spec import ShardingSpec - -from ._utils import generate_resharding_costs, generate_sharding_spec -from .constants import * -from .op_handler import * -from .options import SolverOptions -from .sharding_strategy import ShardingStrategy, StrategiesVector - -__all__ = ['StrategiesConstructor'] - - -class StrategiesConstructor: - """ - StrategiesConstructor is used to construct the parallelization plan for the model execution. - - Args: - graph (Graph): a Graph object used for analysis and strategy generation. - device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster. - solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching. - """ - - def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions): - self.graph = graph - assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' - self.root_module = self.graph.owning_module - self.nodes = list(graph.nodes) - self.device_mesh = device_mesh - self.leaf_strategies = [] - self.strategy_map = {} - self.solver_options = solver_options - - def remove_duplicated_strategy(self, strategies_vector): - ''' - In build_strategies_and_cost method, we may produce some duplicated strategies. - In this method, we will remove the duplicated strategies depending on the strategies name. - ''' - name_checklist = [] - remove_list = [] - for strategy in strategies_vector: - if strategy.name not in name_checklist: - name_checklist.append(strategy.name) - else: - remove_list.append(strategy) - - for strategy in remove_list: - strategies_vector.remove(strategy) - - def _is_bcast_matmul(self, node): - is_bcast_matmul = False - if node.target is torch.matmul and len(node.args) == 2: - lhs_data = node.args[0]._meta_data - rhs_data = node.args[1]._meta_data - if lhs_data.dim() >= 3 and rhs_data.dim() >= 3: - is_bcast_matmul = True - return is_bcast_matmul - - def build_strategies_and_cost(self): - for node in self.nodes: - strategies_vector = StrategiesVector(node) - input_nodes_len = 0 - for check_node in strategies_vector.predecessor_nodes: - if isinstance(check_node._meta_data, torch.Tensor): - input_nodes_len += 1 - # input_nodes_len = len(strategies_vector.predecessor_nodes) - # placeholder node - if node.op == 'placeholder': - # For placeholder nodes, if solver_options.fast is True, we just let them in - # fully replicate status, then strategies of following node will be treated equally due - # to replicate status has no resharding cost to other status. At the same time, the searching - # space is smaller than enumerating all the possible sharding spec for the placeholder node. - # Otherwise, all the possible sharding spec for the placeholder node will be enumerated. - - if self.solver_options.fast: - # create sharding strategy for placeholder - name = 'Replica Placeholder' - dim_partition_dict = {} - output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) - # TODO: use meta_info_prop to profile memory cost - memory_cost = 0 - sharding_strategy_placeholder = ShardingStrategy(name, - output_sharding_spec, - memory_cost=memory_cost) - strategies_vector.append(sharding_strategy_placeholder) - - # get_attr node - if node.op == 'get_attr': - # Same as placeholder nodes, if solver_options.fast is True, we just let them in - # fully replicate status, then strategies of following node will be treated equally due - # to replicate status has no resharding cost to other status. At the same time, the searching - # space is smaller than enumerating all the possible sharding spec for the get_attr node. - # Otherwise, all the possible sharding spec for the get_attr node will be enumerated. - if self.solver_options.fast: - # create sharding strategy for get_attr - name = 'Replica Attribute' - dim_partition_dict = {} - output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) - # TODO: use meta_info_prop to profile memory cost - memory_cost = 0 - sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost) - strategies_vector.append(sharding_strategy_attribute) - - # call_module node - if node.op == 'call_module': - - target = node.target - submod = self.root_module.get_submodule(target) - submod_type = type(submod) - - # conv module - if submod_type in CONV_MODULE_OP: - # use ConvHandler to create sharding strategies for conv module node - conv_handler = ConvHandler(node, self.device_mesh, strategies_vector) - conv_handler.register_strategy() - - # linear module - elif submod_type in LINEAR_MODULE_OP: - # use DotHandler to create sharding strategies for linear module node - dot_handler = DotHandler(node, self.device_mesh, strategies_vector) - dot_handler.register_strategy() - - # element-wise module - elif submod_type in ELEMENTWISE_MODULE_OP: - unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector) - unary_elementwise_handler.register_strategy() - - # BatchNormNd module - elif submod_type in BATCHNORM_MODULE_OP: - # create sharding strategy for element-wise module - norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector) - norm_handler.register_strategy() - # for strategy in norm_handler.strategies_vector: - # print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}') - # assert False - - # MaxPool module - elif submod_type in POOL_MODULE_OP: - # TODO: add sharding constraints on image dimension - # e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension - - # create sharding strategy for element-wise module - assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op.' - input_node = strategies_vector.predecessor_nodes[0] - # For element-wise module, we keep the sharding spec of output node same as - # the input. Therefore, the different strategies of input node with same - # output sharding spec will generate same strategy for element-wise module. - sharding_spec_checklist = [] - for strategy in input_node.strategies_vector: - # It looks a little bit confusing, the input of the processing node - # is the output of the input_node. - input_sharding_spec = strategy.output_sharding_spec - assert isinstance(input_sharding_spec, - ShardingSpec), f'The input node should NOT be a tuple of tensor.' - if input_sharding_spec in sharding_spec_checklist: - continue - - sharding_spec_checklist.append(input_sharding_spec) - dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict) - output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) - - name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}' - - # TODO: use meta_info_prop to profile memory cost and compute cost - compute_cost = node._meta_data.numel() - memory_cost = 0 - resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, - [input_sharding_spec]) - - sharding_strategy = ShardingStrategy(name, - output_sharding_spec, - compute_cost=compute_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=[input_sharding_spec]) - strategies_vector.append(sharding_strategy) - - # embedding module - elif submod_type in EMBEDDING_MODULE_OP: - embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector) - embedding_handler.register_strategy() - - # layernorm module - elif submod_type in LAYERNORM_MODULE_OP: - layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector) - layernorm_handler.register_strategy() - # other module - else: - raise RuntimeError(f'{submod_type} module is NOT supported now.') - - # call_function node - if node.op == 'call_function': - target = node.target - # conv function - if target in CONV_FUNC_OP: - # use ConvHandler to create sharding strategies for conv node - # TODO: the operator_handler does NOT support function node processing now. - conv_handler = ConvHandler(node, self.device_mesh, strategies_vector) - conv_handler.register_strategy() - - # linear function - elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node): - # use DotHandler to create sharding strategies for linear node - # TODO: the operator_handler does NOT support function node processing now. - linear_handler = DotHandler(node, self.device_mesh, strategies_vector) - linear_handler.register_strategy() - - # where function - elif target == torch.where: - if input_nodes_len == 1: - # both of x and y are scalar - pass - - elif input_nodes_len == 2: - # one of x or y is type of scalar - pass - - else: - # general case - where_handler = WhereHandler(node, self.device_mesh, strategies_vector) - where_handler.register_strategy() - - # reshape function - elif target in RESHAPE_FUNC_OP: - # use ReshapeHandler to create sharding strategies for rehsape node - reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector) - reshape_handler.register_strategy() - - # element-wise function - elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1): - unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector) - unary_elementwise_handler.register_strategy() - - # bcast op - elif target in BCAST_FUNC_OP: - if isinstance(node._meta_data, torch.Tensor): - bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector) - bcast_op_handler.register_strategy() - - # torch.var_mean - elif target == torch.var_mean: - dim = node.kwargs['dim'] - input_tensor_node = strategies_vector.predecessor_nodes[0] - for strategy in input_tensor_node.strategies_vector: - input_sharding_spec = strategy.output_sharding_spec - assert isinstance(input_sharding_spec, - ShardingSpec), f'The input node should NOT be a tuple of tensor.' - entire_shape_input = input_sharding_spec.entire_shape - dim_partition_dict_input = input_sharding_spec.dim_partition_dict - name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})' - if dim in dim_partition_dict_input: - # We need to make the action dimension in replicate status - dim_partition_dict_for_input = deepcopy(dim_partition_dict_input) - dim_partition_dict_for_input.pop(dim) - new_input_sharding_spec = ShardingSpec(self.device_mesh, - entire_shape_input, - dim_partition_dict=dim_partition_dict_for_input) - entire_shape_output = deepcopy(entire_shape_input) - entire_shape_output.pop(dim) - dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input) - output_sharding_spec = ShardingSpec(self.device_mesh, - entire_shape_output, - dim_partition_dict=dim_partition_dict_for_input) - # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. - compute_cost = 0 - memory_cost = 0 - resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, - [new_input_sharding_spec]) - sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec), - compute_cost=compute_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=[new_input_sharding_spec]) - - else: - entire_shape_output = deepcopy(entire_shape_input) - entire_shape_output.pop(dim) - dim_partition_dict_for_output = deepcopy(dim_partition_dict_input) - output_sharding_spec = ShardingSpec(self.device_mesh, - entire_shape_output, - dim_partion_dict=dim_partition_dict_input) - # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. - compute_cost = 0 - memory_cost = 0 - resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, - [input_sharding_spec]) - sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec), - compute_cost=compute_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=[input_sharding_spec]) - - strategies_vector.append(sharding_strategy) - - # operator.getitem - elif target == operator.getitem: - index = node.args[1] - input_tensor_node = strategies_vector.predecessor_nodes[0] - for strategy in input_tensor_node.strategies_vector: - if isinstance(strategy.output_sharding_spec, ShardingSpec): - input_sharding_spec = strategy.output_sharding_spec - else: - input_sharding_spec = strategy.output_sharding_spec[index] - assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.' - dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict) - entire_shape_output = deepcopy(input_sharding_spec.entire_shape) - output_sharding_spec = ShardingSpec(self.device_mesh, - entire_shape_output, - dim_partition_dict=dim_partition_dict_for_output) - # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. - compute_cost = 0 - memory_cost = 0 - resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, - [input_sharding_spec], - index=index) - # to prevent the resharding happening, set their resharding cost to inf. - resharding_costs[input_tensor_node] = [ - cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node] - ] - sharding_strategy = ShardingStrategy(name, - output_sharding_spec, - compute_cost=compute_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=[strategy.output_sharding_spec]) - strategies_vector.append(sharding_strategy) - - # torch.arange function - elif target == torch.arange: - name = f'FULLY REPLICATED ARANGE' - entire_shape_output = node._meta_data.shape - dim_partition_dict_for_output = {} - output_sharding_spec = ShardingSpec(self.device_mesh, - entire_shape_output, - dim_partition_dict=dim_partition_dict_for_output) - memory_cost = node._meta_data.numel() - sharding_strategy = ShardingStrategy(name, - output_sharding_spec, - compute_cost=0, - memory_cost=memory_cost) - strategies_vector.append(sharding_strategy) - - # op list to be processed to support gpt2 - elif target in (builtins.getattr, operator.le, torch.addmm): - pass - # other function - else: - raise RuntimeError(f'{target} function is NOT supported now.') - - # call_method node - if node.op == 'call_method': - method = getattr(node.args[0]._meta_data.__class__, node.target) - if method in (torch.Tensor.size,): - pass - elif method in ELEMENTWISE_METHOD_OP: - unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector) - unary_elementwise_handler.register_strategy() - - elif method in RESHAPE_METHOD_OP: - reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector) - reshape_handler.register_strategy() - # print(strategies_vector) - # if len(strategies_vector) == 0: - # print(node) - # assert False - else: - raise RuntimeError(f'{method} function is NOT supported now.') - - # output node - if node.op == 'output': - if self.solver_options.fast: - # create sharding strategy for output - name = 'Replica Output' - input_nodes = strategies_vector.predecessor_nodes - input_sharding_specs = [] - for input_node in input_nodes: - dim_partition_dict_for_input = {} - entire_shape = input_node._meta_data.shape - sharding_spec = ShardingSpec(self.device_mesh, - entire_shape, - dim_partition_dict=dim_partition_dict_for_input) - input_sharding_specs.append(sharding_spec) - - dim_partition_dict = {} - output_sharding_spec = input_sharding_specs - # TODO: use meta_info_prop to profile memory cost - memory_cost = 0 - resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, - input_sharding_specs) - - # clear the resharding cost for the output node - # TODO: we may remove this in final version - for prev_node, resharding_cost_list in resharding_costs.items(): - resharding_costs[prev_node] = [0] * len(resharding_cost_list) - - sharding_strategy_attribute = ShardingStrategy(name, - output_sharding_spec, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=tuple(input_sharding_specs)) - strategies_vector.append(sharding_strategy_attribute) - - self.remove_duplicated_strategy(strategies_vector) - setattr(node, 'strategies_vector', strategies_vector) - self.leaf_strategies.append(strategies_vector) - self.strategy_map[node] = strategies_vector - - # remove no strategy nodes - remove_list = [] - for strategies_vector in self.leaf_strategies: - if len(strategies_vector) == 0: - remove_list.append(strategies_vector.node) - for node in remove_list: - if node.strategies_vector in self.leaf_strategies: - self.leaf_strategies.remove(node.strategies_vector) - if node in self.strategy_map: - self.strategy_map.pop(node) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py deleted file mode 100644 index 96d96a459..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py +++ /dev/null @@ -1,96 +0,0 @@ -from copy import deepcopy -from pickletools import optimize - -import pytest -import torch -import torch.nn as nn -from torch.fx import GraphModule - -from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph -from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.tracer import ColoTracer - - -class ConvModel(nn.Module): - - def __init__(self, c_in, c_out): - super().__init__() - self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3) - self.relu = nn.ReLU() - - def forward(self, x): - x = x * 2 - x = self.conv1(x) - x = x / 2 - x = self.relu(x) - return x - - -def test_cost_graph(): - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - entire_shape = torch.Size((4, 16, 64, 64)) - - tracer = ColoTracer() - model = ConvModel(16, 32) - input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} - - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) - # %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {}) - # %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv1, 2), kwargs = {}) - # %relu : [#users=1] = call_module[target=relu](args = (%truediv,), kwargs = {}) - # return relu - graph = tracer.trace(root=model, meta_args=input_sample) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() - - solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - strategies_constructor.build_strategies_and_cost() - - # (x, mul):{(0, 0): 0} - # (mul, conv1):{(0, 0): 65547.1, (0, 1): 65547.1, (0, 2): 65547.1, (0, 3): 65547.1, (0, 4): 131105.30000000002, (0, 5): 131105.30000000002, (0, 6): 65547.1, (0, 7): 65547.1, (0, 8): 65547.1, (0, 9): 65547.1, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 131105.30000000002, (0, 14): 131105.30000000002} - # (conv1, truediv):{(0, 0): 0, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): 0, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (9, 0): inf, (10, 0): inf, (11, 0): inf, (12, 0): inf, (13, 0): inf, (14, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): inf, (4, 1): inf, (5, 1): 0, (6, 1): inf, (7, 1): inf, (8, 1): inf, (9, 1): inf, (10, 1): inf, (11, 1): inf, (12, 1): inf, (13, 1): inf, (14, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (9, 2): inf, (10, 2): inf, (11, 2): inf, (12, 2): inf, (13, 2): inf, (14, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (9, 3): inf, (10, 3): inf, (11, 3): inf, (12, 3): inf, (13, 3): inf, (14, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): inf, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): 0, (9, 4): inf, (10, 4): inf, (11, 4): inf, (12, 4): inf, (13, 4): inf, (14, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): inf, (6, 5): inf, (7, 5): 0, (8, 5): inf, (9, 5): 0, (10, 5): inf, (11, 5): inf, (12, 5): inf, (13, 5): inf, (14, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): inf, (7, 6): inf, (8, 6): inf, (9, 6): inf, (10, 6): 0, (11, 6): 0, (12, 6): 0, (13, 6): inf, (14, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): inf, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): inf, (8, 7): inf, (9, 7): inf, (10, 7): inf, (11, 7): inf, (12, 7): inf, (13, 7): 0, (14, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): inf, (7, 8): inf, (8, 8): inf, (9, 8): inf, (10, 8): inf, (11, 8): inf, (12, 8): inf, (13, 8): inf, (14, 8): 0} - # (truediv, relu):{(0, 0): 0, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): inf, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): inf, (4, 1): inf, (5, 1): inf, (6, 1): inf, (7, 1): inf, (8, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): 0, (5, 4): inf, (6, 4): inf, (7, 4): inf, (8, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): 0, (6, 5): inf, (7, 5): inf, (8, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): 0, (7, 6): inf, (8, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): inf, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): 0, (8, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): inf, (7, 8): inf, (8, 8): 0} - # (relu, output):{(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 123009.1, (5, 0): 123009.1, (6, 0): 0, (7, 0): 246019.30000000002, (8, 0): 246019.30000000002} - cost_graph = CostGraph(strategies_constructor.leaf_strategies) - - # construct all node pairs - all_node_pairs = [] - - for node in graph.nodes: - if node.op == 'output': - continue - for child in node.users.keys(): - all_node_pairs.append((node, child)) - - for node_pair in all_node_pairs: - assert node_pair in cost_graph.edge_costs - - # construct merged node pairs - merged_node_pairs = [] - node_list = list(graph.nodes) - # add (conv1_weight, conv2d), (conv1_bias, view), (conv2d, add), (view, add), (add, output), (x, conv2d) into check node pairs - merged_node_pairs.append((node_list[0], node_list[4])) - merged_node_pairs.append((node_list[2], node_list[4])) - merged_node_pairs.append((node_list[3], node_list[5])) - merged_node_pairs.append((node_list[5], node_list[6])) - merged_node_pairs.append((node_list[4], node_list[6])) - merged_node_pairs.append((node_list[6], node_list[-1])) - cost_graph.simplify_graph() - for node_pair in all_node_pairs: - if node_pair in merged_node_pairs: - assert node_pair in cost_graph.edge_costs - else: - assert node_pair not in cost_graph.edge_costs - - -if __name__ == '__main__': - test_cost_graph() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_batch_norm_handler.py deleted file mode 100644 index 2d3e71551..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_batch_norm_handler.py +++ /dev/null @@ -1,118 +0,0 @@ -import torch -from torch.fx import GraphModule -import torch.nn as nn -import pytest - -from colossalai.fx.proxy import ColoProxy -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.batch_norm_handler import BatchNormHandler -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.device.device_mesh import DeviceMesh - - -class BNModel(nn.Module): - - def __init__(self, c): - super().__init__() - self.bn = nn.BatchNorm2d(c) - - def forward(self, x): - x = x * 2 - x = self.bn(x) - return x - - -def test_bn_handler(): - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - entire_shape = torch.Size((4, 16, 64, 64)) - - tracer = ColoTracer() - model = BNModel(16) - input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) - # %bn : [#users=1] = call_module[target=bn](args = (%mul,), kwargs = {}) - # return bn - graph = tracer.trace(root=model, meta_args=input_sample) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() - # [x, mul, bn, output] - nodes = [node for node in gm.graph.nodes] - - # find the sharding strategies for the input node of the bn node - # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]] - strategies_vector_for_input = StrategiesVector(nodes[1]) - sharding_option = (None, 0, 1) - for first_sharding_index in sharding_option: - for second_sharding_index in sharding_option: - if first_sharding_index is not None and second_sharding_index == first_sharding_index: - continue - if first_sharding_index is None: - first_dim_spec = _DimSpec([]) - else: - first_dim_spec = _DimSpec([first_sharding_index]) - - if second_sharding_index is None: - second_dim_spec = _DimSpec([]) - else: - second_dim_spec = _DimSpec([second_sharding_index]) - - replica_dim_spec = _DimSpec([]) - sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec] - sharding_spec = ShardingSpec(device_mesh=device_mesh, - entire_shape=entire_shape, - sharding_sequence=sharding_sequence) - strategy_name = str(sharding_spec.sharding_sequence) - sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec) - strategies_vector_for_input.append(sharding_strategy) - setattr(nodes[1], 'strategies_vector', strategies_vector_for_input) - - # generate bn strategy - strategies_vector = StrategiesVector(node=nodes[2]) - bn_handler = BatchNormHandler( - node=nodes[2], - device_mesh=device_mesh, - strategies_vector=strategies_vector, - ) - bn_handler.register_strategy() - # ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01', - # 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN'] - strategy_name_list = [strategy.name for strategy in bn_handler.strategies_vector] - - # RS = RS x S and strategies based on it, such as - # SS = RS x S - assert 'RS0 = RS0 x S0' in strategy_name_list - assert 'S1S0 = RS0 x S0' in strategy_name_list - assert 'RS1 = RS1 x S1' in strategy_name_list - assert 'S0S1 = RS1 x S1' in strategy_name_list - - # RR = RR x R and strategies based on it, such as - # SR = SR x R - assert 'RR = RR x R' in strategy_name_list - assert 'S0R = RR x R' in strategy_name_list - assert 'S1R = RR x R' in strategy_name_list - assert 'S01R = RR x R' in strategy_name_list - - # RS01 = RS01 x S01 - assert 'RS01 = RS01 x S01' in strategy_name_list - - # SR = SR x R WITH SYNC_BN - assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list - assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list - - # SS = SS x S WITH SYNC_BN - assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list - assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list - - # S01R = S01R x R WITH SYNC_BN - assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list - - -if __name__ == '__main__': - test_bn_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py deleted file mode 100644 index 7adc211cf..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py +++ /dev/null @@ -1,75 +0,0 @@ -from cProfile import run - -import pytest -import torch -import torch.nn as nn -from torch.fx import GraphModule - -from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.testing.pytest_wrapper import run_on_environment_flag - - -class ConvModel(nn.Module): - - def __init__(self, c_in, c_out): - super().__init__() - self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1) - self.conv2 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2) - - def forward(self, x): - x1 = self.conv1(x) - x2 = x1 + 1 - x1 = torch.reshape(x1, [1, -1, 64, 1]) - x3 = self.conv2(x1) - x3 = torch.reshape(x3, [4, 1, 64, -1]) - x = x1 + x3 - - return x - - -@run_on_environment_flag(name='AUTO_PARALLEL') -def test_conv_handler(): - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - - tracer = ColoTracer() - model = ConvModel(16, 32) - input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %conv1 : [#users=2] = call_module[target=conv1](args = (%x,), kwargs = {}) - # %add : [#users=0] = call_function[target=operator.add](args = (%conv1, 1), kwargs = {}) - # %reshape : [#users=2] = call_function[target=torch.reshape](args = (%conv1, [1, -1, 64, 1]), kwargs = {}) - # %conv2 : [#users=1] = call_module[target=conv2](args = (%reshape,), kwargs = {}) - # %reshape_1 : [#users=1] = call_function[target=torch.reshape](args = (%conv2, [4, 1, 64, -1]), kwargs = {}) - # %add_1 : [#users=1] = call_function[target=operator.add](args = (%reshape, %reshape_1), kwargs = {}) - # return add_1 - graph = tracer.trace(root=model, meta_args=input_sample) - gm = GraphModule(model, graph, model.__class__.__name__) - # [x, conv1, add, reshape, conv2, reshape_1, add_1, output] - nodes = [node for node in gm.graph.nodes] - solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - - strategies_constructor.build_strategies_and_cost() - strategy_map = strategies_constructor.strategy_map - # check a tensor add with a scalar case - conv1_strategies = strategy_map[nodes[1]] - add_strategies = strategy_map[nodes[2]] - add_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in add_strategies] - for strategy in conv1_strategies: - assert strategy.output_sharding_spec.sharding_sequence in add_strategies_cover_list - - # check two tensors element-wise add case - add_1_strategies = strategy_map[nodes[6]] - assert len(add_1_strategies) == 25 - - -if __name__ == '__main__': - test_conv_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py deleted file mode 100644 index 426d179f1..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py +++ /dev/null @@ -1,54 +0,0 @@ -import pytest -import torch -import torch.nn as nn -from torch.fx import GraphModule - -from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.testing.pytest_wrapper import run_on_environment_flag - - -class MatmulModel(nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, x1, x2): - x = torch.matmul(x1, x2) - - return x - - -@run_on_environment_flag(name='AUTO_PARALLEL') -def test_conv_handler(): - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - - tracer = ColoTracer() - model = MatmulModel() - input_sample = {'x1': torch.rand(4, 4, 8).to('meta'), 'x2': torch.rand(4, 1, 8, 4).to('meta')} - # graph(): - # %x1 : torch.Tensor [#users=1] = placeholder[target=x1] - # %x2 : torch.Tensor [#users=1] = placeholder[target=x2] - # %matmul : [#users=1] = call_function[target=torch.matmul](args = (%x1, %x2), kwargs = {}) - # return matmul - graph = tracer.trace(root=model, meta_args=input_sample) - gm = GraphModule(model, graph, model.__class__.__name__) - # [x1, x2, matmul, output] - nodes = [node for node in gm.graph.nodes] - solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - - strategies_constructor.build_strategies_and_cost() - strategy_map = strategies_constructor.strategy_map - matmul_strategies = strategy_map[nodes[2]] - assert len(matmul_strategies) == 30 - - -if __name__ == '__main__': - test_conv_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py deleted file mode 100644 index 9342e06a0..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py +++ /dev/null @@ -1,90 +0,0 @@ -import pytest -import torch -import torch.nn as nn -from torch.fx import GraphModule - -from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import ConvHandler -from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.proxy import ColoProxy -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec - - -class ConvModel(nn.Module): - - def __init__(self, c_in, c_out): - super().__init__() - self.conv = nn.Conv2d(c_in, c_out, kernel_size=3) - - def forward(self, x): - x = x * 2 - x = self.conv(x) - return x - - -def test_conv_handler(): - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - entire_shape = torch.Size((4, 16, 64, 64)) - - tracer = ColoTracer() - model = ConvModel(16, 32) - input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) - # %conv_weight : [#users=1] = get_attr[target=conv.weight] - # %conv_bias : [#users=1] = get_attr[target=conv.bias] - # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)}) - # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) - # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) - # return add - graph = tracer.trace(root=model, meta_args=input_sample) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() - solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - - strategies_constructor.build_strategies_and_cost() - conv_node = list(graph.nodes)[4] - # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'] - strategy_name_list = [strategy.name for strategy in conv_node.strategies_vector] - - # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list - - # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list - - # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list - - # RS = RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list - - # RR= RR x RR - assert 'RR = RR x RR' in strategy_name_list - - # SR = SR x RR - assert 'S0R = S0R x RR' in strategy_name_list - assert 'S1R = S1R x RR' in strategy_name_list - assert 'S01R = S01R x RR' in strategy_name_list - - # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list - assert 'RR = RS01 x S01R' in strategy_name_list - - -if __name__ == '__main__': - test_conv_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py deleted file mode 100644 index 0a2dba161..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -import torch -import torch.nn as nn -from torch.fx import GraphModule - -from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.dot_handler import DotHandler -from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.proxy import ColoProxy -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec - - -class LinearModel(nn.Module): - - def __init__(self, in_features, out_features): - super().__init__() - self.linear = nn.Linear(in_features, out_features) - - def forward(self, x): - x = x * 2 - x = self.linear(x) - return x - - -@pytest.mark.skip('F.linear is not supported in deprecated handler') -def test_dot_handler(): - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - entire_shape = torch.Size((4, 8)) - - tracer = ColoTracer() - model = LinearModel(8, 16) - input_sample = {'x': torch.rand(4, 8).to('meta')} - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) - # %linear_weight : [#users=1] = get_attr[target=linear.weight] - # %linear_bias : [#users=1] = get_attr[target=linear.bias] - # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%mul, %linear_weight), kwargs = {}) - # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) - # return add - graph = tracer.trace(root=model, meta_args=input_sample) - - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() - solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - - strategies_constructor.build_strategies_and_cost() - linear_node = list(graph.nodes)[4] - - # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR'] - strategy_name_list = [strategy.name for strategy in linear_node.strategies_vector] - - # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list - - # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list - - # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list - - # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list - - # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list - - -if __name__ == '__main__': - test_dot_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_layer_norm_handler.py deleted file mode 100644 index 40e227cb5..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_layer_norm_handler.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -from torch.fx import GraphModule -import torch.nn as nn -import pytest -from colossalai.auto_parallel.tensor_shard.deprecated import sharding_strategy - -from colossalai.fx.proxy import ColoProxy -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.layer_norm_handler import LayerNormHandler -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.device.device_mesh import DeviceMesh - - -class LNModel(nn.Module): - - def __init__(self, c): - super().__init__() - self.ln = nn.LayerNorm(c) - - def forward(self, x): - x = x * 2 - x = self.ln(x) - return x - - -def test_bn_handler(): - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - entire_shape = torch.Size((4, 4, 128)) - - tracer = ColoTracer() - model = LNModel(128) - input_sample = {'x': torch.rand(4, 4, 128).to('meta')} - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) - # %ln : [#users=1] = call_module[target=ln](args = (%mul,), kwargs = {}) - # return ln - graph = tracer.trace(root=model, meta_args=input_sample) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() - # [x, mul, ln, output] - nodes = [node for node in gm.graph.nodes] - sharding_spec_for_input = ShardingSpec(device_mesh, entire_shape, {}) - sharding_strategy_for_input = ShardingStrategy('node_1', sharding_spec_for_input) - strategies_vector_for_input = StrategiesVector(nodes[1]) - strategies_vector_for_input.append(sharding_strategy_for_input) - setattr(nodes[1], 'strategies_vector', strategies_vector_for_input) - - # generate bn strategy - strategies_vector = StrategiesVector(node=nodes[2]) - ln_handler = LayerNormHandler( - node=nodes[2], - device_mesh=device_mesh, - strategies_vector=strategies_vector, - ) - ln_handler.register_strategy() - # ['[S0, R, R] = [S0, R, R] x [R]', '[R, S0, R] = [R, S0, R] x [R]', '[S1, R, R] = [S1, R, R] x [R]', '[R, S1, R] = [R, S1, R] x [R]', - # '[S0, S1, R] = [S0, S1, R] x [R]', '[S1, S0, R] = [S1, S0, R] x [R]', '[S01, R, R] = [S01, R, R] x [R]', '[R, S01, R] = [R, S01, R] x [R]', 'RR = RR x R'] - strategy_name_list = [strategy.name for strategy in ln_handler.strategies_vector] - - assert len(strategy_name_list) == 9 - - -if __name__ == '__main__': - test_bn_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py deleted file mode 100644 index ac9df4cd8..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -import torch.nn as nn -from torch.fx import GraphModule - -from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.tracer import ColoTracer - - -class ConvModel(nn.Module): - - def __init__(self, c_in, c_out): - super().__init__() - self.conv = nn.Conv2d(c_in, c_out, kernel_size=3) - - def forward(self, x): - x = self.conv(x) - x = torch.flatten(x) - return x - - -def test_conv_handler(): - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - - tracer = ColoTracer() - model = ConvModel(16, 32) - input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %conv_weight : [#users=1] = get_attr[target=conv.weight] - # %conv_bias : [#users=1] = get_attr[target=conv.bias] - # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)}) - # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) - # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) - # %flatten : [#users=1] = call_function[target=torch.flatten](args = (%add,), kwargs = {}) - # return flatten - graph = tracer.trace(root=model, meta_args=input_sample) - gm = GraphModule(model, graph, model.__class__.__name__) - # [x, conv, flatten, output] - nodes = [node for node in gm.graph.nodes] - solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - - strategies_constructor.build_strategies_and_cost() - strategy_map = strategies_constructor.strategy_map - add_strategies = strategy_map[nodes[5]] - flatten_strategies = strategy_map[nodes[6]] - flatten_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in flatten_strategies] - for strategy in add_strategies: - assert strategy.output_sharding_spec.sharding_sequence in flatten_strategies_cover_list - - -if __name__ == '__main__': - test_conv_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py deleted file mode 100644 index 294a59fc8..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py +++ /dev/null @@ -1,66 +0,0 @@ -import torch -from torch.fx import GraphModule -import torch.nn as nn -import pytest - -from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing.pytest_wrapper import run_on_environment_flag - - -class ConvModel(nn.Module): - - def __init__(self, dim_in, dim_out): - super().__init__() - self.dim_in = dim_in - self.dim_out = dim_out - - def forward(self, condition, x, y): - output = torch.where(condition, x, y) - - return output - - -@run_on_environment_flag(name='AUTO_PARALLEL') -def test_where_handler(): - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - - tracer = ColoTracer() - model = ConvModel(16, 32) - input_sample = { - 'condition': torch.rand(16, 32).to('meta'), - 'x': torch.rand(16, 32).to('meta'), - 'y': torch.rand(16, 32).to('meta') - } - # graph(): - # %condition : torch.Tensor [#users=1] = placeholder[target=condition] - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %y : torch.Tensor [#users=1] = placeholder[target=y] - # %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {}) - # return where - graph = tracer.trace(root=model, meta_args=input_sample) - gm = GraphModule(model, graph, model.__class__.__name__) - - # [condition, x, y, where, output] - nodes = [node for node in gm.graph.nodes] - solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - - strategies_constructor.build_strategies_and_cost() - strategy_map = strategies_constructor.strategy_map - # check a tensor add with a scalar case - where_node = strategy_map[nodes[3]] - # ['[S0, S1] = [S0, S1] x [S0, S1] x [S0, S1]', '[S1, S0] = [S1, S0] x [S1, S0] x [S1, S0]', '[S01, R] = [S01, R] x [S01, R] x [S01, R]', - # '[R, S01] = [R, S01] x [R, S01] x [R, S01]', '[S0, R] = [S0, R] x [S0, R] x [S0, R]', '[R, S0] = [R, S0] x [R, S0] x [R, S0]', - # '[S1, R] = [S1, R] x [S1, R] x [S1, R]', '[R, S1] = [R, S1] x [R, S1] x [R, S1]', '[R, R] = [R, R] x [R, R] x [R, R]'] - assert len(where_node) == 9 - - -if __name__ == '__main__': - test_where_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py deleted file mode 100644 index 3286b325c..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py +++ /dev/null @@ -1,86 +0,0 @@ -from functools import partial -import pytest -import torch -import torch.multiprocessing as mp -from torch.fx import GraphModule -import torch.nn as nn -import pytest -from colossalai.initialize import launch -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.logging import disable_existing_loggers -from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph -from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor - -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass -from colossalai.auto_parallel.tensor_shard.deprecated import Solver -from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from colossalai.testing.pytest_wrapper import run_on_environment_flag - - -class ConvModel(nn.Module): - - def __init__(self, c_in, c_out): - super().__init__() - self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False) - - def forward(self, x): - x = self.conv(x) - return x - - -def check_apply(rank, world_size, port): - disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - input = torch.rand(4, 4, 4, 4).cuda() - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = torch.Size((4, 4, 8, 8)) - - tracer = ColoTracer() - model = ConvModel(4, 4).cuda() - origin_output = model(input) - input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')} - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {}) - # return conv - graph = tracer.trace(root=model, meta_args=input_sample) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() - solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - strategies_constructor.build_strategies_and_cost() - - cost_graph = CostGraph(strategies_constructor.leaf_strategies) - cost_graph.simplify_graph() - graph_analyser = GraphAnalyser(gm) - solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) - ret = solver.call_solver_serialized_args() - solution = list(ret[0]) - sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh) - shape_consistency_pass(gm) - gm.recompile() - nodes = [node for node in gm.graph.nodes] - # TODO: wrap the gm to avoid the influence of the user training code - output = gm(input, sharding_spec_dict, origin_spec_dict) - assert output.equal(origin_output) - - -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_apply(): - world_size = 4 - run_func = partial(check_apply, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_apply() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py deleted file mode 100644 index baa70727a..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py +++ /dev/null @@ -1,79 +0,0 @@ -from copy import deepcopy - -import pytest -import torch -import torch.nn as nn -from torch.fx import GraphModule - -from colossalai.auto_parallel.tensor_shard.deprecated import Solver -from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph -from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser -from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.testing.pytest_wrapper import run_on_environment_flag - - -class ConvModel(nn.Module): - - def __init__(self, c_in, c_out): - super().__init__() - self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3) - self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3) - self.conv3 = nn.Conv2d(c_out, c_out, kernel_size=3) - self.relu = nn.ReLU() - - def forward(self, x): - x = x * 2 - x = self.conv1(x) - x = self.conv2(x) - x = x / 2 - x = self.conv3(x) - x = self.relu(x) - return x - - -@run_on_environment_flag(name='AUTO_PARALLEL') -def test_solver(): - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - shape_consistency_manager = ShapeConsistencyManager() - - tracer = ColoTracer() - model = ConvModel(16, 32) - input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} - - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) - # %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {}) - # %conv2 : [#users=1] = call_module[target=conv2](args = (%conv1,), kwargs = {}) - # %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv2, 2), kwargs = {}) - # %conv3 : [#users=1] = call_module[target=conv3](args = (%truediv,), kwargs = {}) - # %relu : [#users=1] = call_module[target=relu](args = (%conv3,), kwargs = {}) - # return relu - graph = tracer.trace(root=model, meta_args=input_sample) - gm = GraphModule(model, graph, model.__class__.__name__) - - solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - strategies_constructor.build_strategies_and_cost() - - cost_graph = CostGraph(strategies_constructor.leaf_strategies) - cost_graph.simplify_graph() - graph_analyser = GraphAnalyser(gm) - solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) - ret = solver.call_solver_serialized_args() - - # [ 0 0 13 13 13 13 13 0] - strategies_combination_list = ret[0] - assert solver.leaf_strategies[2][13].name == 'S01R = S01R x RR' - - -if __name__ == '__main__': - test_solver() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py deleted file mode 100644 index e90d6b153..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py +++ /dev/null @@ -1,81 +0,0 @@ -import torch -from torch.fx import GraphModule -import torch.nn as nn -import pytest - -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.device.device_mesh import DeviceMesh -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph -from copy import deepcopy -from colossalai.auto_parallel.tensor_shard.deprecated import Solver -import transformers -from colossalai.auto_parallel.tensor_shard.deprecated.constants import * -from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser -from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from colossalai.testing.pytest_wrapper import run_on_environment_flag - -BATCH_SIZE = 8 -SEQ_LENGHT = 8 - - -@run_on_environment_flag(name='AUTO_PARALLEL') -def test_cost_graph(): - physical_mesh_id = torch.arange(0, 8) - mesh_shape = (2, 4) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - shape_consistency_manager = ShapeConsistencyManager() - - tracer = ColoTracer() - config = transformers.GPT2Config(n_position=1024, n_layer=1, n_head=12) - model = transformers.GPT2LMHeadModel(config=config) - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) - token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) - kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - meta_args = {k: v.to('meta') for k, v in kwargs.items()} - - graph = tracer.trace(root=model, meta_args=meta_args) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() - graph_analyser = GraphAnalyser(gm) - liveness_list = graph_analyser.liveness_analysis() - solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - print(graph) - strategies_constructor.build_strategies_and_cost() - for check_node, strategies_vector in strategies_constructor.strategy_map.items(): - print(check_node, len(strategies_vector)) - cost_graph = CostGraph(strategies_constructor.leaf_strategies) - cost_graph.simplify_graph() - # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0) - solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) - - ret = solver.call_solver_serialized_args() - print(ret) - strategies_list = list(ret[0]) - print(strategies_list) - computation_cost = 0 - communication_cost = 0 - memory_cost = 0 - nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] - for index, node in enumerate(nodes): - print(node.name, node.strategies_vector[strategies_list[index]].name) - computation_cost += node.strategies_vector[strategies_list[index]].compute_cost - communication_cost += node.strategies_vector[strategies_list[index]].communication_cost - node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost - if isinstance(node_memory_cost, tuple): - node_memory_cost = node_memory_cost[0] - memory_cost += node_memory_cost - - print(f'computation cost is {computation_cost}') - print(f'communication cost is {communication_cost}') - print(f'memory cost is {memory_cost}') - - -if __name__ == '__main__': - test_cost_graph() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py deleted file mode 100644 index 415156ed6..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py +++ /dev/null @@ -1,94 +0,0 @@ -import torch -from torch.fx import GraphModule -import torch.nn as nn -import pytest - -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.device.device_mesh import DeviceMesh -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph -from copy import deepcopy -from colossalai.auto_parallel.tensor_shard.deprecated import Solver -from torchvision.models import resnet34, resnet50 -from colossalai.auto_parallel.tensor_shard.deprecated.constants import * -from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser -from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from colossalai.testing.pytest_wrapper import run_on_environment_flag - - -class MLP(torch.nn.Module): - - def __init__(self, dim: int): - super().__init__() - self.linear1 = torch.nn.Linear(dim, dim * 4) - self.linear2 = torch.nn.Linear(dim * 4, dim) - self.dropout = torch.nn.Dropout(0) - self.relu = torch.nn.ReLU() - - def forward(self, x): - x = self.linear1(x) - x = self.dropout(x) - x = self.relu(x) - x = self.linear2(x) - return x - - -@run_on_environment_flag(name='AUTO_PARALLEL') -def test_cost_graph(): - physical_mesh_id = torch.arange(0, 8) - mesh_shape = (2, 4) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - shape_consistency_manager = ShapeConsistencyManager() - - tracer = ColoTracer() - model = MLP(32) - - input_sample = {'x': torch.rand(16, 32).to('meta')} - - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {}) - # %dropout : [#users=1] = call_module[target=dropout](args = (%linear1,), kwargs = {}) - # %relu : [#users=1] = call_module[target=relu](args = (%dropout,), kwargs = {}) - # %linear2 : [#users=1] = call_module[target=linear2](args = (%relu,), kwargs = {}) - # return linear2 - graph = tracer.trace(root=model, meta_args=input_sample) - - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() - graph_analyser = GraphAnalyser(gm) - liveness_list = graph_analyser.liveness_analysis() - solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - strategies_constructor.build_strategies_and_cost() - - cost_graph = CostGraph(strategies_constructor.leaf_strategies) - cost_graph.simplify_graph() - # # megatron mode if no memory constraints - # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) - # all sharding on out feature dim if memory budget is not sufficient for megatron mode - solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=5500.0) - - ret = solver.call_solver_serialized_args() - strategies_list = list(ret[0]) - computation_cost = 0 - communication_cost = 0 - memory_cost = 0 - for index, node in enumerate(graph.nodes): - print(node.name, node.strategies_vector[strategies_list[index]].name) - computation_cost += node.strategies_vector[strategies_list[index]].compute_cost - communication_cost += node.strategies_vector[strategies_list[index]].communication_cost - node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost - if isinstance(node_memory_cost, tuple): - node_memory_cost = node_memory_cost[0] - memory_cost += node_memory_cost - - print(f'computation cost is {computation_cost}') - print(f'communication cost is {communication_cost}') - print(f'memory cost is {memory_cost}') - - -if __name__ == '__main__': - test_cost_graph() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py deleted file mode 100644 index 9be1a5d96..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py +++ /dev/null @@ -1,103 +0,0 @@ -from copy import deepcopy - -import pytest -import torch -import torch.nn as nn -from torch.fx import GraphModule - -from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import CONV_STRATEGIES_LIST -from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.proxy import ColoProxy -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec - - -class ConvModel(nn.Module): - - def __init__(self, c_in, c_out): - super().__init__() - self.conv = nn.Conv2d(c_in, c_out, kernel_size=3) - - def forward(self, x): - x = x * 2 - x = self.conv(x) - return x - - -def test_strategies_constructor(): - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - entire_shape = torch.Size((4, 16, 64, 64)) - - tracer = ColoTracer() - model = ConvModel(16, 32) - input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) - # %conv_weight : [#users=1] = get_attr[target=conv.weight] - # %conv_bias : [#users=1] = get_attr[target=conv.bias] - # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)}) - # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) - # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) - # return add - graph = tracer.trace(root=model, meta_args=input_sample) - print(graph) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() - - solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - - assert strategies_constructor.leaf_strategies == [] - assert strategies_constructor.strategy_map == {} - strategies_constructor.build_strategies_and_cost() - - # check leaf_strategies - - # In fast mode, placeholder node only has replica strategy. - assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder' - - # Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec. - assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]_0' - - # Third node is conv. - conv_check_list = deepcopy(CONV_STRATEGIES_LIST) - for strategy in strategies_constructor.leaf_strategies[4]: - conv_check_list.remove(strategy.name) - assert len(conv_check_list) == 0 - - # In fast mode, output node only has replica strategy. - assert strategies_constructor.leaf_strategies[7][0].name == 'Replica Output' - - # check strategy_map - - nodes = [node for node in graph.nodes] - # In fast mode, placeholder node only has replica strategy. - x = nodes[0] - assert strategies_constructor.strategy_map[x][0].name == 'Replica Placeholder' - - # Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec. - mul = nodes[1] - assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]_0' - - # fifth node is conv. - conv = nodes[4] - conv_check_list = deepcopy(CONV_STRATEGIES_LIST) - for strategy in strategies_constructor.strategy_map[conv]: - conv_check_list.remove(strategy.name) - assert len(conv_check_list) == 0 - - # In fast mode, output node only has replica strategy. - output = nodes[-1] - assert strategies_constructor.strategy_map[output][0].name == 'Replica Output' - - -if __name__ == '__main__': - test_strategies_constructor() From d03f4429c1155eb806d9b0763a43dfe4184a98f9 Mon Sep 17 00:00:00 2001 From: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com> Date: Wed, 15 Feb 2023 09:55:53 +0800 Subject: [PATCH 16/30] add ci (#2641) --- examples/images/diffusion/README.md | 6 ++++++ examples/images/diffusion/main.py | 2 ++ examples/images/diffusion/test_ci.sh | 17 +++++++++++++++++ 3 files changed, 25 insertions(+) mode change 100644 => 100755 examples/images/diffusion/test_ci.sh diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md index 952da5d1c..15932f1f5 100644 --- a/examples/images/diffusion/README.md +++ b/examples/images/diffusion/README.md @@ -92,6 +92,12 @@ cd ColossalAI CUDA_EXT=1 pip install . ``` +#### Step 3:Accelerate with flash attention by xformers(Optional) + +``` +pip install xformers +``` + ### Option #2: Use Docker To use the stable diffusion Docker image, you can either build using the provided the [Dockerfile](./docker/Dockerfile) or pull a Docker image from our Docker hub. diff --git a/examples/images/diffusion/main.py b/examples/images/diffusion/main.py index 5f166aa1f..4dd88a5ec 100644 --- a/examples/images/diffusion/main.py +++ b/examples/images/diffusion/main.py @@ -539,6 +539,8 @@ if __name__ == "__main__": raise ValueError("-n/--name and -r/--resume cannot be specified both." "If you want to resume training in a new log folder, " "use -n/--name in combination with --resume_from_checkpoint") + + ckpt = None if opt.resume: rank_zero_info("Resuming from {}".format(opt.resume)) if not os.path.exists(opt.resume): diff --git a/examples/images/diffusion/test_ci.sh b/examples/images/diffusion/test_ci.sh old mode 100644 new mode 100755 index e69de29bb..51ceeb41d --- a/examples/images/diffusion/test_ci.sh +++ b/examples/images/diffusion/test_ci.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -euxo pipefail + +conda env create -f environment.yaml + +conda activate ldm + +conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch +pip install transformers diffusers invisible-watermark + +CUDA_EXT=1 pip install colossalai + +pip install pytorch-lightning + +wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt + +python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt 512-base-ema.ckpt From 89f8975fb8f0ca7e637fce33f0d13248094ccb4d Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 15 Feb 2023 10:12:55 +0800 Subject: [PATCH 17/30] [workflow] fixed tensor-nvme build caching (#2711) --- .github/workflows/build_on_pr.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index c7882db6e..f595e6773 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -96,6 +96,7 @@ jobs: - name: Store TensorNVMe Cache run: | + cd TensorNVMe cp -p -r ./build /github/home/tensornvme_cache/ - name: Checkout Colossal-AI From cb2c6a2415d6da40ad694e3c7a7b3dae647ac073 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Wed, 15 Feb 2023 10:36:19 +0800 Subject: [PATCH 18/30] [autoparallel] refactor runtime pass (#2644) * [autoparallel] refactor runtime pass * add unit test * polish --- colossalai/auto_parallel/passes/constants.py | 5 + .../passes/runtime_preparation_pass.py | 439 +++++++++--------- .../test_pass/test_node_converting_pass.py | 54 +++ .../test_size_value_converting_pass.py | 65 +++ .../test_node_handler/test_linear_handler.py | 3 - 5 files changed, 352 insertions(+), 214 deletions(-) create mode 100644 tests/test_auto_parallel/test_pass/test_node_converting_pass.py create mode 100644 tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py diff --git a/colossalai/auto_parallel/passes/constants.py b/colossalai/auto_parallel/passes/constants.py index b86088474..485a87492 100644 --- a/colossalai/auto_parallel/passes/constants.py +++ b/colossalai/auto_parallel/passes/constants.py @@ -6,3 +6,8 @@ OUTPUT_SAVED_MOD = [ torch.nn.ReLU, torch.nn.Softmax, ] + +# SHAPE_ARGUMENT_OPS contains node with (input, *shape) style args. +# This list could be extended if any other method has the same +# argument style as view and reshape. +SHAPE_ARGUMENT_OPS = [torch.Tensor.view, torch.Tensor.reshape, torch.reshape] diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index ecf3f1f18..bb419be35 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -19,6 +19,8 @@ from colossalai.tensor.comm_spec import _all_reduce from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec +from .constants import SHAPE_ARGUMENT_OPS + shape_consistency_manager = ShapeConsistencyManager() @@ -51,23 +53,16 @@ def size_processing(size: Union[int, torch.Size], return size -def _solution_annotatation(gm: torch.fx.GraphModule, - solution: List[int], - strategies_constructor: StrategiesConstructor = None): +def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], + strategies_constructor: StrategiesConstructor): """ This method is used to stick the solution strategy to the nodes and add the information required in runtime into graph as placeholder nodes. """ mod_graph = gm.graph - # TODO: In future PR, strategies_constructor should be a required argument, - # instead of optional argument. This is because we don't need to consider nodes with - # no strategy in runtime preparation pass. - if strategies_constructor is not None: - nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] - no_strategy_nodes = strategies_constructor.no_strategy_nodes - else: - nodes = tuple(mod_graph.nodes) - no_strategy_nodes = [] + + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + no_strategy_nodes = strategies_constructor.no_strategy_nodes # the dict to get origin sharding spec of node origin_node_sharding_spec_dict = {} @@ -97,6 +92,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule, target_sharding_specs.append(target_sharding_spec) sharding_spec_convert_dict[index] = target_sharding_specs setattr(node, 'target_sharding_specs', target_sharding_specs) + # the get_attr node strategy is kind of pending strategy, which means we will change it # to the same strategy of the user node. if node.op == 'get_attr': @@ -134,7 +130,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule, return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict -def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): +def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): """ In the auto parallel system, tensors may get shard on different devices, so the size of tensors need to be converted to the size of original tensor and managed by the users, such as torch.view, @@ -145,6 +141,80 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): nodes = tuple(mod_graph.nodes) node_pairs = {} + # DeviceMesh information instructs the scaling of the size value + device_mesh_info = {} + for dim, dim_size in enumerate(device_mesh.mesh_shape): + device_mesh_info[dim] = dim_size + + def _extract_target_dim(node): + ''' + A helper function to etract the target dimension from size node. + There are two usages of torch.Tensor.size: + 1. tensor.size() + 2. tensor.size(dim) + + If a target_dim is assigned, then the output will be in type of int, instead of torch.Size. + Otherwise, the output will be in type of torch.Size and this function will return None. + ''' + target_dim = None + if len(node.args) > 1: + target_dim = node.args[1] + if target_dim < 0: + target_dim += node.args[0]._meta_data.dim() + return target_dim + + def _post_processing(node, size_processing_node): + ''' + This function is used to process the dependency between the size node and its users after + inserting the size_process_node. + ''' + # store original node and processing node pair in node_pairs dictioanry + # It will be used to replace the original node with processing node in slice object + node_pairs[node] = size_processing_node + size_processing_node._meta_data = node._meta_data + if 'activation_checkpoint' in node.meta: + size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint'] + + user_list = list(node.users.keys()) + for user in user_list: + if user == size_processing_node: + continue + new_args = list(user.args) + new_kwargs = dict(user.kwargs) + # the origin node may be a positional argument or key word argument of user node + if node in new_args: + # substitute the origin node with size_processing_node + new_args[new_args.index(node)] = size_processing_node + user.args = tuple(new_args) + elif str(node) in new_kwargs: + # substitute the origin node with size_processing_node + new_kwargs[str(node)] = size_processing_node + user.kwargs = new_kwargs + + def _update_slice_object_args(slice_object): + ''' + This function is used to update the slice object argument list. + If the slice object contains the Node argument, then the size node will be replaced with + ''' + if isinstance(slice_object, slice): + start = slice_object.start + stop = slice_object.stop + step = slice_object.step + if start in node_pairs: + start = node_pairs[start] + if stop in node_pairs: + stop = node_pairs[stop] + if step in node_pairs: + step = node_pairs[step] + return slice(start, stop, step) + elif isinstance(slice_object, int): + if slice_object in node_pairs: + return node_pairs[slice_object] + else: + return slice_object + else: + raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}") + for node in nodes: if node.op == 'call_method' and node.target == 'size': @@ -154,49 +224,15 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): sharding_spec = node.args[0].sharding_spec dim_partition_dict = sharding_spec.dim_partition_dict - # there are two usages of torch.Tensor.size: - # tensor.size() - # tensor.size(dim) - # if a target_dim is assigned, then the output will be - # in type of int, instead of torch.Size - target_dim = None - if len(node.args) > 1: - target_dim = node.args[1] - if target_dim < 0: - target_dim += node.args[0]._meta_data.dim() - - # DeviceMesh information instructs the scaling of the size value - device_mesh_info = {} - for dim, dim_size in enumerate(device_mesh.mesh_shape): - device_mesh_info[dim] = dim_size + target_dim = _extract_target_dim(node) + # insert size_processing node with mod_graph.inserting_after(node): size_processing_node = mod_graph.create_node('call_function', size_processing, args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name)) - # store original node and processing node pair in node_pairs dictioanry - # It will be used to replace the original node with processing node in slice object - node_pairs[node] = size_processing_node - size_processing_node._meta_data = node._meta_data - if 'activation_checkpoint' in node.meta: - size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint'] - - user_list = list(node.users.keys()) - for user in user_list: - if user == size_processing_node: - continue - new_args = list(user.args) - new_kwargs = dict(user.kwargs) - # the origin node may be a positional argument or key word argument of user node - if node in new_args: - # substitute the origin node with size_processing_node - new_args[new_args.index(node)] = size_processing_node - user.args = tuple(new_args) - elif str(node) in new_kwargs: - # substitute the origin node with size_processing_node - new_kwargs[str(node)] = size_processing_node - user.kwargs = new_kwargs + _post_processing(node, size_processing_node) if node.op == 'call_function' and node.target == operator.getitem: @@ -217,14 +253,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): # In this pass, we need process the last two cases because # node arguments may potentially appear in these cases. if isinstance(getitem_index, slice): - new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step - if getitem_index.start in node_pairs: - new_start = node_pairs[getitem_index.start] - elif getitem_index.stop in node_pairs: - new_stop = node_pairs[getitem_index.stop] - elif getitem_index.step in node_pairs: - new_step = node_pairs[getitem_index.step] - new_slice_item = slice(new_start, new_stop, new_step) + new_slice_item = _update_slice_object_args(getitem_index) new_args = (node.args[0], new_slice_item) node.args = new_args @@ -237,16 +266,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): if slice_item is None: new_slice_items.append(None) continue - - new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step - - if slice_item.start in node_pairs: - new_start = node_pairs[slice_item.start] - elif slice_item.stop in node_pairs: - new_stop = node_pairs[slice_item.stop] - elif slice_item.step in node_pairs: - new_step = node_pairs[slice_item.step] - new_slice_item = slice(new_start, new_stop, new_step) + new_slice_item = _update_slice_object_args(slice_item) new_slice_items.append(new_slice_item) new_args = (node.args[0], tuple(new_slice_items)) @@ -255,104 +275,109 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): return gm -def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): +def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): """ This pass will process node args to adapt the distributed tensor layout. """ mod_graph = gm.graph nodes = tuple(mod_graph.nodes) + def _extract_info_from_sharding_spec(sharding_spec): + ''' + This function is used to extract the dim_partition_dict and device_mesh from + sharding spec instance or a list of sharding spec. + ''' + if isinstance(sharding_spec, ShardingSpec): + dim_partition_dict = sharding_spec.dim_partition_dict + device_mesh = sharding_spec.device_mesh + return dim_partition_dict, device_mesh + if sharding_spec is None: + return None, None + assert isinstance(sharding_spec, + (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None' + + device_mesh = sharding_spec[0].device_mesh + dim_partition_dict = [] + for element in sharding_spec: + dim_partition_dict.append(_extract_info_from_sharding_spec(element)) + return dim_partition_dict, sharding_spec + + def _process_node_arguments(node): + new_args = [] + for arg in node.args: + # There are two args style: + # 1. (input, *shape) + # 2. (input, shape) + # We will extract the elements from shape and add them into the new_args + # Finally, the args style of new_args will be unified to (input, *shape) + if isinstance(arg, Node): + if isinstance(arg._meta_data, (tuple, list)): + new_args.extend(arg._meta_data) + elif isinstance(arg._meta_data, int): + new_args.append(arg._meta_data) + else: + new_args.append(arg) + else: + assert isinstance(arg, + (int, tuple, list)), 'The argument in view node should be either type of Node or int.' + if isinstance(arg, (tuple, list)): + new_args.extend(arg) + else: + new_args.append(arg) + return new_args + + def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node): + new_args = _process_node_arguments(node) + if node.op == 'call_method': + args_to_process = list(new_args[1:]) + else: + args_to_process = list(new_args) + for dim, shard_dims in dim_partition_dict.items(): + total_shard_size = 1 + for shard_dim in shard_dims: + total_shard_size *= device_mesh.shape[shard_dim] + + # we will skip the dim with -1 value + if args_to_process[dim] == -1: + continue + else: + # TODO: add assertion here to make sure the dim size is divisible by total_shard_size + args_to_process[dim] //= total_shard_size + + args_to_process = tuple(args_to_process) + + if node.op == 'call_method': + new_args = (new_args[0],) + args_to_process + else: + new_args = args_to_process + + node.args = new_args + + def _filter_node_with_shape_args(node): + if node.op == 'call_method': + target = getattr(node.args[0]._meta_data.__class__, node.target) + elif node.op == 'call_function': + target = node.target + else: + target = None + + if target in SHAPE_ARGUMENT_OPS: + return True + return False + for node in nodes: # skip the placeholder node added in _solution_annotation pass if not hasattr(node, 'sharding_spec'): continue - def _process_sharding_spec(sharding_spec): - if isinstance(sharding_spec, ShardingSpec): - dim_partition_dict = sharding_spec.dim_partition_dict - device_mesh = sharding_spec.device_mesh - return dim_partition_dict, device_mesh - if sharding_spec is None: - return None, None - assert isinstance(sharding_spec, - (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None' - - device_mesh = sharding_spec[0].device_mesh - dim_partition_dict = [] - for element in sharding_spec: - dim_partition_dict.append(_process_sharding_spec(element)) - return dim_partition_dict, sharding_spec - - output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec) - new_args = [] - - if node.op == 'call_method': - method = getattr(node.args[0]._meta_data.__class__, node.target) - # process the node with (input, *shape) style args - if method in (torch.Tensor.view, torch.Tensor.reshape): - - for arg in node.args: - if isinstance(arg, Node): - if isinstance(arg._meta_data, (int, tuple, list)): - new_args.append(arg._meta_data) - else: - new_args.append(arg) - else: - assert isinstance( - arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.' - new_args.append(arg) - - for dim, shard_dims in output_dim_partition_dict.items(): - total_shard_size = 1 - for shard_dim in shard_dims: - total_shard_size *= device_mesh.shape[shard_dim] - # There are two ways to use torch.view: - # 1. torch.view(input, *shape) - # 2. torch.view(input, shape) - if isinstance(new_args[1], int): - # we will skip the dim with -1 value - if new_args[dim + 1] == -1: - continue - else: - new_args[dim + 1] //= total_shard_size - else: - new_args[1] = list(new_args[1]) - # we will skip the dim with -1 value - if new_args[1][dim] == -1: - continue - else: - new_args[1][dim] //= total_shard_size - node.args = tuple(new_args) - - elif node.op == 'call_function': - target = node.target - # process the node with (input, torch.Size) style args - if target in (torch.reshape,): - for arg in node.args: - if isinstance(arg, Node): - if isinstance(arg._meta_data, (tuple, list)): - new_args.append(list(arg._meta_data)) - else: - new_args.append(arg) - else: - assert isinstance( - arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.' - new_args.append(list(arg)) - - for dim, shard_dims in output_dim_partition_dict.items(): - # we will skip the dim with -1 value - if new_args[1][dim] == -1: - continue - total_shard_size = 1 - for shard_dim in shard_dims: - total_shard_size *= device_mesh.shape[shard_dim] - new_args[1][dim] //= total_shard_size - node.args = tuple(new_args) + output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec) + if _filter_node_with_shape_args(node): + _scale_args_adapt_sharding_spec(output_dim_partition_dict, device_mesh, node) return gm -def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False): +def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False): """ Apply the sharding action to the module parameters and buffers following the instructions of solver solution. @@ -361,6 +386,49 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o nodes = tuple(mod_graph.nodes) # This stream is created for overlaping the communication and computation. reduction_stream = torch.cuda.Stream() + + def _add_hook_for_grad_communication(node, param): + + comm_actions = node.best_strategy.communication_actions + + def _filter_param_to_hook(node, op_data, comm_action): + if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK: + return True + if node.op == 'get_attr' and isinstance( + node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: + return True + return False + + for operation_data, comm_action in comm_actions.items(): + comm_spec_to_use = comm_action.comm_spec + # register hook to the parameters + if _filter_param_to_hook(node, operation_data, comm_action): + + def wrapper(param, comm_spec, stream, overlap): + + def hook_fn(grad): + if overlap: + with torch.cuda.stream(stream): + _all_reduce(grad, comm_spec, async_op=True) + else: + _all_reduce(grad, comm_spec, async_op=False) + + param.register_hook(hook_fn) + + wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap) + + def _shard_param(param, target_sharding_spec): + # apply the sharding spec of parameters + if target_sharding_spec.dim_partition_dict != {}: + origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {}) + setattr(param, 'sharding_spec', origin_sharding_spec) + # TODO: build a ColoParamter class to manager the distributed parameters + # we could use .data here, because all the operations just happen before the real training + # loop, so we don't need to track these operations in the autograd graph. + param = torch.nn.Parameter( + shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, + target_sharding_spec).detach().clone()) + for node in nodes: if node.op == 'call_module': target_module = node.graph.owning_module.get_submodule(node.target) @@ -370,36 +438,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o setattr(target_module, 'processed', True) for name, param in target_module.named_parameters(): target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) - # apply the sharding spec of parameters - if target_sharding_spec.dim_partition_dict != {}: - origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {}) - setattr(param, 'sharding_spec', origin_sharding_spec) - # TODO: build a ColoParamter class to manager the distributed parameters - # we could use .data here, because all the operations just happen before the real training - # loop, so we don't need to track these operations in the autograd graph. - param = torch.nn.Parameter( - shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, - target_sharding_spec).detach().clone()) + _shard_param(param, target_sharding_spec) setattr(target_module, name, param) - comm_actions = node.best_strategy.communication_actions - for operation_data, comm_action in comm_actions.items(): - comm_spec_to_use = comm_action.comm_spec - # register hook to the parameters - if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK: - - def wrapper(param, comm_spec, stream, overlap): - - def hook_fn(grad): - if overlap: - with torch.cuda.stream(stream): - _all_reduce(grad, comm_spec, async_op=True) - else: - _all_reduce(grad, comm_spec, async_op=False) - - param.register_hook(hook_fn) - - wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap) + _add_hook_for_grad_communication(node, param) sharded_buffer_dict = {} # apply the sharding spec of buffers @@ -427,37 +469,12 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o target = getattr(target_module, atoms[-1]) target_sharding_spec = node.sharding_spec - if target_sharding_spec.dim_partition_dict != {}: - origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {}) - setattr(target, 'sharding_spec', origin_sharding_spec) - # TODO: build a ColoParamter class to manager the distributed parameters - # we could use .data here, because all the operations just happen before the real training - # loop, so we don't need to track these operations in the autograd graph. - target = torch.nn.Parameter( - shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec, - target_sharding_spec).detach().clone()) + _shard_param(target, target_sharding_spec) assert hasattr(target_module, atoms[-1]) setattr(target_module, atoms[-1], target) + _add_hook_for_grad_communication(node, target) - comm_actions = node.best_strategy.communication_actions - for operation_data, comm_action in comm_actions.items(): - comm_spec_to_use = comm_action.comm_spec - # register hook to the parameters - if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: - - def wrapper(param, comm_spec, stream, overlap): - - def hook_fn(grad): - if overlap: - with torch.cuda.stream(stream): - _all_reduce(grad, comm_spec, async_op=True) - else: - _all_reduce(grad, comm_spec, async_op=False) - - param.register_hook(hook_fn) - - wrapper(target, comm_spec_to_use, reduction_stream, overlap=overlap) return gm @@ -471,14 +488,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule): def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh, - strategies_constructor: StrategiesConstructor = None, + strategies_constructor: StrategiesConstructor, overlap=False): - gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation( + gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass( gm, solution, strategies_constructor) - gm = _size_value_converting(gm, device_mesh) - gm = _node_args_converting(gm, device_mesh) + gm = size_value_converting_pass(gm, device_mesh) + gm = node_args_converting_pass(gm, device_mesh) # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed. # gm = implicit_comm_action_apply(gm) - gm = _module_params_sharding(gm, device_mesh, overlap=overlap) + gm = module_params_sharding_pass(gm, device_mesh, overlap=overlap) return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict diff --git a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py new file mode 100644 index 000000000..d0d107610 --- /dev/null +++ b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py @@ -0,0 +1,54 @@ +import torch +import torch.nn.functional as F + +from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec + + +class TestModule(torch.nn.Module): + + def forward(self, x): + x = x.view(4, 4, 2) + return x + + +def insert_narrow(gm, x_node): + graph = gm.graph + with graph.inserting_after(x_node): + shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={}) + view_node = list(x_node.users.keys())[0] + new_args = list(view_node.args) + new_args[0] = shard_node + view_node.args = tuple(new_args) + return gm + + +def test_node_args_converting_pass(): + model = TestModule() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + meta_args = {'x': torch.rand(4, 8).to('meta')} + input = torch.rand(4, 8) + tracer = ColoTracer() + graph = tracer.trace(root=model, meta_args=meta_args) + + x_node = list(graph.nodes)[0] + view_node = list(graph.nodes)[1] + sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) + setattr(x_node, 'sharding_spec', sharding_spec) + setattr(view_node, 'sharding_spec', sharding_spec) + + gm = ColoGraphModule(model, graph) + gm = node_args_converting_pass(gm, device_mesh) + gm = insert_narrow(gm, x_node) + gm.recompile() + output = gm(input) + assert output.shape == torch.Size([2, 4, 2]) + + +if __name__ == '__main__': + test_node_args_converting_pass() diff --git a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py new file mode 100644 index 000000000..349483008 --- /dev/null +++ b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py @@ -0,0 +1,65 @@ +import torch +import torch.nn.functional as F + +from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec + + +class TestModule(torch.nn.Module): + + def forward(self, x): + size = x.size() + return size + + +def insert_narrow(gm, x_node): + graph = gm.graph + with graph.inserting_after(x_node): + shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={}) + size_node = list(x_node.users.keys())[0] + size_node.args = (shard_node,) + return gm + + +def recover_narrow(gm, narrow_node): + graph = gm.graph + size_node = list(graph.nodes)[2] + x_node = narrow_node.args[0] + size_node.args = (x_node,) + graph.erase_node(narrow_node) + return gm + + +def test_size_value_converting_pass(): + model = TestModule() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + meta_args = {'x': torch.rand(4, 8).to('meta')} + input = torch.rand(4, 8) + tracer = ColoTracer() + graph = tracer.trace(root=model, meta_args=meta_args) + + x_node = list(graph.nodes)[0] + x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) + setattr(x_node, 'sharding_spec', x_sharding_spec) + gm = ColoGraphModule(model, graph) + gm = insert_narrow(gm, x_node) + gm.recompile() + size = gm(input) + assert size == torch.Size([2, 8]) + + narrow_node = list(gm.graph.nodes)[1] + gm = recover_narrow(gm, narrow_node) + gm = size_value_converting_pass(gm, device_mesh) + gm = insert_narrow(gm, x_node) + gm.recompile() + size = gm(input) + assert size == torch.Size([4, 8]) + + +if __name__ == '__main__': + test_size_value_converting_pass() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index 3d268ea43..18afacf56 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -1,12 +1,9 @@ -from faulthandler import disable from functools import partial -from xml.dom import WrongDocumentErr import pytest import torch import torch.multiprocessing as mp import torch.nn as nn -from typing_extensions import Self from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( From f6b4ca4e6cc2e7822e38cdc61da0566aee129828 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 15 Feb 2023 10:53:54 +0800 Subject: [PATCH 19/30] [devops] add chatgpt ci (#2713) --- .github/workflows/run_chatgpt_examples.yml | 41 ++++++++++++++++++++ .github/workflows/run_chatgpt_unit_tests.yml | 41 ++++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 .github/workflows/run_chatgpt_examples.yml create mode 100644 .github/workflows/run_chatgpt_unit_tests.yml diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml new file mode 100644 index 000000000..9d7c1ff99 --- /dev/null +++ b/.github/workflows/run_chatgpt_examples.yml @@ -0,0 +1,41 @@ +name: Run ChatGPT examples + +on: + pull_request: + types: [synchronize, opened, reopened] + paths: + - 'applications/ChatGPT/chatgpt/**' + - 'applications/ChatGPT/requirements.txt' + - 'applications/ChatGPT/setup.py' + - 'applications/ChatGPT/examples/**' + + +jobs: + tests: + name: Run ChatGPT examples + runs-on: [self-hosted, gpu] + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt + timeout-minutes: 30 + defaults: + run: + shell: bash + steps: + - name: Checkout ColossalAI + uses: actions/checkout@v2 + + - name: Install ColossalAI and ChatGPT + run: | + pip install -v . + cd applications/ChatGPT + pip install -v . + pip install -r examples/requirements.txt + + - name: Execute Examples + run: | + ./examples/test_ci.sh + env: + NCCL_SHM_DISABLE: 1 + MAX_JOBS: 8 + PROMPT_PATH: /data/scratch/chatgpt/prompts.csv diff --git a/.github/workflows/run_chatgpt_unit_tests.yml b/.github/workflows/run_chatgpt_unit_tests.yml new file mode 100644 index 000000000..3ac0d2d8c --- /dev/null +++ b/.github/workflows/run_chatgpt_unit_tests.yml @@ -0,0 +1,41 @@ +name: Run ChatGPT unit tests + +on: + pull_request: + types: [synchronize, opened, reopened] + paths: + - 'applications/ChatGPT/chatgpt/**' + - 'applications/ChatGPT/requirements.txt' + - 'applications/ChatGPT/setup.py' + - 'applications/ChatGPT/requirements-test.txt' + - 'applications/ChatGPT/tests/**' + - 'applications/ChatGPT/pytest.ini' + +jobs: + tests: + name: Run ChatGPT unit tests + runs-on: [self-hosted, gpu] + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt + timeout-minutes: 30 + defaults: + run: + shell: bash + steps: + - name: Checkout ColossalAI + uses: actions/checkout@v2 + + - name: Install ColossalAI and ChatGPT + run: | + pip install -v . + cd applications/ChatGPT + pip install -v . + pip install -r requirements-test.txt + + - name: Execute Unit Testing + run: | + pytest tests/ + env: + NCCL_SHM_DISABLE: 1 + MAX_JOBS: 8 From d4d3387f452a26720506ac75cca4e754987eb748 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Wed, 15 Feb 2023 11:08:35 +0800 Subject: [PATCH 20/30] [doc] add open-source contribution invitation (#2714) * [doc] fix typo * [doc] add invitation --- README.md | 2 +- applications/ChatGPT/README.md | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 20a5f2606..c2ad6ffc7 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/colossal-ai_logo_vertical.png)](https://www.colossalai.org/) - Colossal-AI: Make big AI models cheaper, easier, and scalable + Colossal-AI: Making big AI models cheaper, easier, and scalable

    Paper | Documentation | diff --git a/applications/ChatGPT/README.md b/applications/ChatGPT/README.md index 43085f3ab..b3ea239a9 100644 --- a/applications/ChatGPT/README.md +++ b/applications/ChatGPT/README.md @@ -60,6 +60,19 @@ We also support training reward model with true-world data. See `examples/train_ - [ ] integrate with Ray - [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL) +## Invitation to open-source contribution +Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build an ecosystem with Colossal-AI, making efforts towards the era of big AI models from the starting point of replicating ChatGPT! + +You may contact us or participate in the following ways: +1. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) or submitting a [PR](https://github.com/hpcaitech/ColossalAI/pulls) on GitHub +2. Join the Colossal-AI community on +[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +and [WeChat](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. +3. Check out and fill in the [cooperation proposal](https://www.hpc-ai.tech/partners) +4. Send your proposal to email contact@hpcaitech.com + +Thanks so much to all of our amazing contributors! + ## Quick Preview

    From 2045d45ab73d6f0458964f38edd92b4637d34556 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 15 Feb 2023 11:24:18 +0800 Subject: [PATCH 21/30] [doc] updated documentation version list (#2715) --- .github/workflows/doc_build_after_merge.yml | 2 +- docs/versions.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/doc_build_after_merge.yml b/.github/workflows/doc_build_after_merge.yml index dae3b70e1..2f7b708ab 100644 --- a/.github/workflows/doc_build_after_merge.yml +++ b/.github/workflows/doc_build_after_merge.yml @@ -5,6 +5,7 @@ on: pull_request: paths: - 'version.txt' + - 'docs/' types: - closed @@ -16,7 +17,6 @@ jobs: steps: - name: trigger workflow in ColossalAI-Documentation run: | - gh curl \ -X POST \ -H "Accept: application/vnd.github+json" \ diff --git a/docs/versions.json b/docs/versions.json index dde32982b..49a0fab2b 100644 --- a/docs/versions.json +++ b/docs/versions.json @@ -1,3 +1,3 @@ [ - "current" + "v0.2.4" ] From 5b24987fa75adee654aac0b02c8805fb8042cc05 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Wed, 15 Feb 2023 12:25:50 +0800 Subject: [PATCH 22/30] [autoparallel] fix parameters sharding bug (#2716) --- .../auto_parallel/passes/runtime_preparation_pass.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index bb419be35..e63bfdfe7 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -426,8 +426,9 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes # we could use .data here, because all the operations just happen before the real training # loop, so we don't need to track these operations in the autograd graph. param = torch.nn.Parameter( - shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, - target_sharding_spec).detach().clone()) + shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, + target_sharding_spec).detach().clone()) + return param for node in nodes: if node.op == 'call_module': @@ -438,7 +439,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes setattr(target_module, 'processed', True) for name, param in target_module.named_parameters(): target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) - _shard_param(param, target_sharding_spec) + param = _shard_param(param, target_sharding_spec) setattr(target_module, name, param) _add_hook_for_grad_communication(node, param) @@ -469,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes target = getattr(target_module, atoms[-1]) target_sharding_spec = node.sharding_spec - _shard_param(target, target_sharding_spec) + target = _shard_param(target, target_sharding_spec) assert hasattr(target_module, atoms[-1]) setattr(target_module, atoms[-1], target) From 21d6a48f4d9a4c3880110aecde9015fbe303ce9f Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Wed, 15 Feb 2023 13:48:28 +0800 Subject: [PATCH 23/30] [autoparallel] add shard option (#2696) * [autoparallel] add shard option * polish --- .../auto_parallel/tensor_shard/initialize.py | 70 ++++++++++++++++--- .../tensor_shard/node_handler/__init__.py | 5 +- .../tensor_shard/node_handler/node_handler.py | 31 +++++--- .../tensor_shard/node_handler/option.py | 17 ----- .../auto_parallel/tensor_shard/options.py | 49 +++++++++++++ .../tensor_shard/solver/__init__.py | 3 +- .../tensor_shard/solver/options.py | 30 -------- .../tensor_shard/solver/solver.py | 2 +- .../solver/strategies_constructor.py | 26 +++++-- .../test_gpt/test_solver_with_gpt_module.py | 9 +-- .../test_tensor_shard/test_metainfo/utils.py | 3 +- .../test_node_handler/test_shard_option.py | 14 +++- .../test_node_handler/utils.py | 3 +- .../test_param_resharding_cost.py | 9 +-- .../test_solver_with_resnet_v2.py | 9 +-- 15 files changed, 176 insertions(+), 104 deletions(-) delete mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/option.py create mode 100644 colossalai/auto_parallel/tensor_shard/options.py delete mode 100644 colossalai/auto_parallel/tensor_shard/solver/options.py diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py index 23ed0f433..012b0ff43 100644 --- a/colossalai/auto_parallel/tensor_shard/initialize.py +++ b/colossalai/auto_parallel/tensor_shard/initialize.py @@ -8,14 +8,9 @@ from torch.fx.graph import Graph from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction -from colossalai.auto_parallel.tensor_shard.solver import ( - CostGraph, - GraphAnalyser, - Solver, - SolverOptions, - StrategiesConstructor, -) +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.graph_module import ColoGraphModule @@ -69,13 +64,43 @@ def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[f pass -def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh): +def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, + shard_option: str): ''' This method is used to build the strategy_constructor for the given graph. After this method, each node in the graph will have a strategies_vector which is constructed by the related node handler. ''' - solver_options = SolverOptions() + if solver_preference == 'standard': + solver_preference = SolverPerference.STANDARD + elif solver_preference == 'tp': + solver_preference = SolverPerference.TP + elif solver_preference == 'dp': + solver_preference = SolverPerference.DP + else: + raise ValueError(f'Invalid solver_preference: {solver_preference}') + + if dataloader_option == 'replicated': + dataloader_option = DataloaderOption.REPLICATED + elif dataloader_option == 'distributed': + dataloader_option = DataloaderOption.DISTRIBUTED + else: + raise ValueError(f'Invalid dataloader_option: {dataloader_option}') + + if shard_option == 'standard': + shard_option = ShardOption.STANDARD + elif shard_option == 'shard': + shard_option = ShardOption.SHARD + elif shard_option == 'shard_last_axis': + shard_option = ShardOption.SHARD_LAST_AXIS + elif shard_option == 'full_shard': + shard_option = ShardOption.FULL_SHARD + else: + raise ValueError(f'Invalid shard_option: {shard_option}') + + solver_options = SolverOptions(solver_perference=solver_preference, + dataloader_option=dataloader_option, + shard_option=shard_option) strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() @@ -183,6 +208,9 @@ def initialize_model(model: nn.Module, device_mesh: DeviceMesh, memory_budget: float = -1.0, overlap: bool = False, + solver_preference: str = 'standard', + dataloader_option: str = 'replicated', + shard_option: str = 'standard', save_solver_solution: bool = False, load_solver_solution: bool = False, solution_path: str = None, @@ -198,6 +226,12 @@ def initialize_model(model: nn.Module, the memory budget will be infinity. overlap(optional): the overlap is used to specify whether to overlap gradient communication and backward computing. + solver_preference(optional): the solver_preference is used to specify which parallelism algorithm + has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'. + dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will + be used. The valid dataloader_option could be 'replicated' or 'distributed'. + shard_option(optional): the shard_option is used to specify how many axes will be used to shard the + model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'. save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved to the solution_path. load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded @@ -212,7 +246,12 @@ def initialize_model(model: nn.Module, graph = tracer.trace(root=model, meta_args=meta_args) gm = ColoGraphModule(model, graph, model.__class__.__name__) gm.recompile() - strategies_constructor = build_strategy_constructor(graph, device_mesh) + + strategies_constructor = build_strategy_constructor(graph, + device_mesh, + solver_preference=solver_preference, + dataloader_option=dataloader_option, + shard_option=shard_option) if load_solver_solution: solution = torch.load(solution_path) else: @@ -240,6 +279,9 @@ def autoparallelize(model: nn.Module, alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, logical_mesh_shape: Tuple[int] = None, logical_mesh_id: torch.Tensor = None, + solver_preference: str = 'standard', + dataloader_option: str = 'replicated', + shard_option: str = 'standard', save_solver_solution: bool = False, load_solver_solution: bool = False, solver_solution_path: str = None, @@ -262,6 +304,12 @@ def autoparallelize(model: nn.Module, mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be generated by search_best_logical_mesh_shape function. logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id. + solver_preference(optional): the solver_preference is used to specify which parallelism algorithm + has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'. + dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will + be used. The valid dataloader_option could be 'replicated' or 'distributed'. + shard_option(optional): the shard_option is used to specify how many axes will be used to shard the + model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'. save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved to the solution_path. load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded @@ -280,6 +328,8 @@ def autoparallelize(model: nn.Module, rst_to_unpack = initialize_model(model, meta_args, device_mesh, + solver_preference=solver_preference, + dataloader_option=dataloader_option, save_solver_solution=save_solver_solution, load_solver_solution=load_solver_solution, solution_path=solver_solution_path, diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py index 0050358ce..9903ca54e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -11,7 +11,6 @@ from .layer_norm_handler import LayerNormModuleHandler from .linear_handler import LinearFunctionHandler, LinearModuleHandler from .matmul_handler import MatMulHandler from .normal_pooling_handler import NormPoolingHandler -from .option import ShardOption from .output_handler import OutputHandler from .permute_handler import PermuteHandler from .placeholder_handler import PlaceholderHandler @@ -31,6 +30,6 @@ __all__ = [ 'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler', 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler', - 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption', - 'TransposeHandler', 'SplitHandler' + 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler', + 'SplitHandler' ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index c6f8d035a..136e57c5e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -5,7 +5,7 @@ import torch from torch.fx.node import Node from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register -from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption +from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, OperationDataType, @@ -32,19 +32,19 @@ class NodeHandler(ABC): strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector. ''' - def __init__( - self, - node: Node, - device_mesh: DeviceMesh, - strategies_vector: StrategiesVector, - shard_option: ShardOption = ShardOption.STANDARD, - ) -> None: + def __init__(self, + node: Node, + device_mesh: DeviceMesh, + strategies_vector: StrategiesVector, + shard_option: ShardOption = ShardOption.STANDARD, + solver_perference: SolverPerference = SolverPerference.STANDARD) -> None: self.node = node self.predecessor_node = list(node._input_nodes.keys()) self.successor_node = list(node.users.keys()) self.device_mesh = device_mesh self.strategies_vector = strategies_vector self.shard_option = shard_option + self.solver_perference = solver_perference def update_resharding_cost(self, strategy: ShardingStrategy) -> None: """ @@ -187,15 +187,24 @@ class NodeHandler(ABC): remove_strategy_list = [] for strategy in self.strategies_vector: - shard_level = 0 + shard_axis_list = [] + last_axis = len(self.device_mesh.mesh_shape) - 1 for op_data, sharding_spec in strategy.sharding_specs.items(): if op_data.data is not None and isinstance(op_data.data, torch.Tensor): - for dim, shard_axis in sharding_spec.dim_partition_dict.items(): - shard_level += len(shard_axis) + for dim, shard_axes in sharding_spec.dim_partition_dict.items(): + for shard_axis in shard_axes: + if shard_axis not in shard_axis_list: + shard_axis_list.append(shard_axis) + + shard_level = len(shard_axis_list) + using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list if self.shard_option == ShardOption.SHARD and shard_level == 0: remove_strategy_list.append(strategy) if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1: remove_strategy_list.append(strategy) + if self.shard_option == ShardOption.SHARD_LAST_AXIS: + if shard_level != 1 or using_last_axis == False: + remove_strategy_list.append(strategy) for strategy in remove_strategy_list: self.strategies_vector.remove(strategy) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/option.py b/colossalai/auto_parallel/tensor_shard/node_handler/option.py deleted file mode 100644 index dffb0386d..000000000 --- a/colossalai/auto_parallel/tensor_shard/node_handler/option.py +++ /dev/null @@ -1,17 +0,0 @@ -from enum import Enum - -__all__ = ['ShardOption'] - - -class ShardOption(Enum): - """ - This enum class is to define the shard level required in node strategies. - - Notes: - STANDARD: We do not add any extra shard requirements. - SHARD: We require the node to be shard using at least one device mesh axis. - FULL_SHARD: We require the node to be shard using all device mesh axes. - """ - STANDARD = 0 - SHARD = 1 - FULL_SHARD = 2 diff --git a/colossalai/auto_parallel/tensor_shard/options.py b/colossalai/auto_parallel/tensor_shard/options.py new file mode 100644 index 000000000..f0ea502a6 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/options.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass +from enum import Enum + +__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption'] + + +class SolverPerference(Enum): + """ + This enum class is to define the solver preference. + """ + STANDARD = 0 + DP = 1 + TP = 2 + + +class ShardOption(Enum): + """ + This enum class is to define the shard level required in node strategies. + + Notes: + STANDARD: We do not add any extra shard requirements. + SHARD: We require the node to be shard using at least one device mesh axis. + SHARD_ONE_AXIS: We require the node to be shard using the last device mesh axis. + FULL_SHARD: We require the node to be shard using all device mesh axes. + TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis. + TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes. + """ + STANDARD = 0 + SHARD = 1 + SHARD_LAST_AXIS = 2 + FULL_SHARD = 3 + + +class DataloaderOption(Enum): + """ + This enum class is to define the dataloader option. + """ + REPLICATED = 0 + DISTRIBUTED = 1 + + +@dataclass +class SolverOptions: + """ + SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search. + """ + solver_perference: SolverPerference = SolverPerference.STANDARD + dataloader_option: DataloaderOption = DataloaderOption.REPLICATED + shard_option: ShardOption = ShardOption.STANDARD diff --git a/colossalai/auto_parallel/tensor_shard/solver/__init__.py b/colossalai/auto_parallel/tensor_shard/solver/__init__.py index e9f9ba881..f9e6bd923 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/solver/__init__.py @@ -1,7 +1,6 @@ from .cost_graph import CostGraph from .graph_analysis import GraphAnalyser -from .options import SolverOptions from .solver import Solver from .strategies_constructor import StrategiesConstructor -__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', 'SolverOptions'] +__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph'] diff --git a/colossalai/auto_parallel/tensor_shard/solver/options.py b/colossalai/auto_parallel/tensor_shard/solver/options.py deleted file mode 100644 index b52e55708..000000000 --- a/colossalai/auto_parallel/tensor_shard/solver/options.py +++ /dev/null @@ -1,30 +0,0 @@ -from dataclasses import dataclass -from enum import Enum - -__all__ = ['SolverOptions'] - - -class SolverPerference(Enum): - """ - This enum class is to define the solver preference. - """ - STANDARD = 0 - DP = 1 - TP = 2 - - -class DataloaderOption(Enum): - """ - This enum class is to define the dataloader option. - """ - REPLICATED = 0 - DISTRIBUTED = 1 - - -@dataclass -class SolverOptions: - """ - SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search. - """ - solver_perference: SolverPerference = SolverPerference.STANDARD - dataloader_option: DataloaderOption = DataloaderOption.REPLICATED diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py index 89d0da223..3bc3e8960 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/solver.py +++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py @@ -33,7 +33,7 @@ class Solver: solution_numbers: int = 1, forward_only: bool = False, memory_increasing_coefficient: float = 1.3, - verbose=True): + verbose=False): ''' Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. Argument: diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py index 042b9bb4b..40741daca 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVe from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec from colossalai.device.device_mesh import DeviceMesh -from .options import DataloaderOption, SolverOptions +from ..options import DataloaderOption, SolverOptions __all__ = ['StrategiesConstructor'] @@ -101,7 +101,11 @@ class StrategiesConstructor: # get_attr node elif node.op == 'get_attr': - getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector) + getattr_handler = GetattrHandler(node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference) getattr_handler.register_strategy() # call_module node @@ -109,7 +113,11 @@ class StrategiesConstructor: target = node.target submod = self.root_module.get_submodule(target) submod_type = type(submod) - handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector) + handler = operator_registry.get(submod_type)(node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference) handler.register_strategy() # attach metainfo_vector to node if hasattr(handler, 'metainfo_vector'): @@ -118,7 +126,11 @@ class StrategiesConstructor: # call_function node elif node.op == 'call_function': target = node.target - handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector) + handler = operator_registry.get(target)(node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference) handler.register_strategy() # attach metainfo_vector to node if hasattr(handler, 'metainfo_vector'): @@ -127,7 +139,11 @@ class StrategiesConstructor: # call_method node elif node.op == 'call_method': method = getattr(node.args[0]._meta_data.__class__, node.target) - handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector) + handler = operator_registry.get(method)(node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference) handler.register_strategy() # attach metainfo_vector to node if hasattr(handler, 'metainfo_vector'): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py index 26ad0d3a0..a6be1928b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py @@ -4,13 +4,8 @@ import transformers from torch.fx import GraphModule from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP -from colossalai.auto_parallel.tensor_shard.solver import ( - CostGraph, - GraphAnalyser, - Solver, - SolverOptions, - StrategiesConstructor, -) +from colossalai.auto_parallel.tensor_shard.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.shape_consistency import ShapeConsistencyManager diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py index b8c01d358..60ecd1dd9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py @@ -7,8 +7,9 @@ from torch.fx import GraphModule from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType, TrainCycleItem -from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.tracer.tracer import ColoTracer diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py index fda041110..f6895d92a 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py @@ -5,7 +5,7 @@ import torch.multiprocessing as mp import torch.nn as nn from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler -from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption +from colossalai.auto_parallel.tensor_shard.options import ShardOption from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer @@ -49,6 +49,15 @@ def check_shard_option(shard_option): strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] + if shard_option == ShardOption.SHARD_LAST_AXIS: + # RR = RS x SR + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS1 = RR x RS1' in strategy_name_list + + return + # SS = SR x RS assert 'S1S0 = S1R x RS0_0' in strategy_name_list assert 'S0S1 = S0R x RS1_1' in strategy_name_list @@ -104,7 +113,8 @@ def check_shard_option(shard_option): @run_on_environment_flag(name='AUTO_PARALLEL') def test_shard_option(): - for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD]: + # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]: + for shard_option in [ShardOption.SHARD_LAST_AXIS]: check_shard_option(shard_option) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py index db76ed9b8..14c8cb296 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py @@ -6,7 +6,8 @@ from torch.fx import GraphModule from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass -from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser from colossalai.auto_parallel.tensor_shard.solver.solver import Solver diff --git a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py index b504d59c9..92f011ba3 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py @@ -1,13 +1,8 @@ import torch +from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType -from colossalai.auto_parallel.tensor_shard.solver import ( - CostGraph, - GraphAnalyser, - Solver, - SolverOptions, - StrategiesConstructor, -) +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.testing.pytest_wrapper import run_on_environment_flag diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py index f4a5ae7ac..6f64acd52 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -3,13 +3,8 @@ from torch.fx import GraphModule from torchvision.models import resnet50 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP -from colossalai.auto_parallel.tensor_shard.solver import ( - CostGraph, - GraphAnalyser, - Solver, - SolverOptions, - StrategiesConstructor, -) +from colossalai.auto_parallel.tensor_shard.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.shape_consistency import ShapeConsistencyManager From 9c0943ecdbd0a1f489de94c22e202cd3ebf8efb0 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 15 Feb 2023 13:59:58 +0800 Subject: [PATCH 24/30] [chatgpt] optimize generation kwargs (#2717) * [chatgpt] ppo trainer use default generate args * [chatgpt] example remove generation preparing fn * [chatgpt] benchmark remove generation preparing fn * [chatgpt] fix ci --- .github/workflows/run_chatgpt_examples.yml | 1 + .github/workflows/run_chatgpt_unit_tests.yml | 1 + .../ChatGPT/benchmarks/benchmark_gpt_dummy.py | 3 -- .../benchmarks/benchmark_opt_lora_dummy.py | 3 -- applications/ChatGPT/chatgpt/trainer/ppo.py | 10 +++++ applications/ChatGPT/examples/train_dummy.py | 45 ++++++++----------- .../ChatGPT/examples/train_prompts.py | 37 ++++++++------- 7 files changed, 48 insertions(+), 52 deletions(-) diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml index 9d7c1ff99..af59c8db2 100644 --- a/.github/workflows/run_chatgpt_examples.yml +++ b/.github/workflows/run_chatgpt_examples.yml @@ -34,6 +34,7 @@ jobs: - name: Execute Examples run: | + cd applications/ChatGPT ./examples/test_ci.sh env: NCCL_SHM_DISABLE: 1 diff --git a/.github/workflows/run_chatgpt_unit_tests.yml b/.github/workflows/run_chatgpt_unit_tests.yml index 3ac0d2d8c..8dcf21fe2 100644 --- a/.github/workflows/run_chatgpt_unit_tests.yml +++ b/.github/workflows/run_chatgpt_unit_tests.yml @@ -35,6 +35,7 @@ jobs: - name: Execute Unit Testing run: | + cd applications/ChatGPT pytest tests/ env: NCCL_SHM_DISABLE: 1 diff --git a/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py index 8474f3ba7..3e66e4e7a 100644 --- a/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py +++ b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py @@ -5,7 +5,6 @@ import torch import torch.distributed as dist import torch.nn as nn from chatgpt.nn import GPTActor, GPTCritic, RewardModel -from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn from chatgpt.trainer import PPOTrainer from chatgpt.trainer.callbacks import PerformanceEvaluator from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy @@ -151,8 +150,6 @@ def main(args): top_k=50, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, - prepare_inputs_fn=gpt_prepare_inputs_fn, - update_model_kwargs_fn=update_model_kwargs_fn, callbacks=[performance_evaluator]) random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device()) diff --git a/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py b/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py index accbc4155..8cee5489e 100644 --- a/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py +++ b/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py @@ -5,7 +5,6 @@ import torch import torch.distributed as dist import torch.nn as nn from chatgpt.nn import OPTActor, OPTCritic, RewardModel -from chatgpt.nn.generation_utils import opt_prepare_inputs_fn, update_model_kwargs_fn from chatgpt.trainer import PPOTrainer from chatgpt.trainer.callbacks import PerformanceEvaluator from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy @@ -144,8 +143,6 @@ def main(args): top_k=50, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, - prepare_inputs_fn=opt_prepare_inputs_fn, - update_model_kwargs_fn=update_model_kwargs_fn, callbacks=[performance_evaluator]) random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device()) diff --git a/applications/ChatGPT/chatgpt/trainer/ppo.py b/applications/ChatGPT/chatgpt/trainer/ppo.py index 85beb223e..b1d11b224 100644 --- a/applications/ChatGPT/chatgpt/trainer/ppo.py +++ b/applications/ChatGPT/chatgpt/trainer/ppo.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional import torch.nn as nn from chatgpt.experience_maker import Experience, NaiveExperienceMaker from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss +from chatgpt.nn.generation_utils import update_model_kwargs_fn from chatgpt.replay_buffer import NaiveReplayBuffer from torch.optim import Optimizer @@ -59,6 +60,7 @@ class PPOTrainer(Trainer): dataloader_pin_memory: bool = True, callbacks: List[Callback] = [], **generate_kwargs) -> None: + self._set_default_generate_kwargs(generate_kwargs, actor) actor = Actor(strategy.setup_model(actor.model)) critic = strategy.setup_model(critic) reward_model = strategy.setup_model(reward_model) @@ -102,3 +104,11 @@ class PPOTrainer(Trainer): self.critic_optim.zero_grad() return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()} + + def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None: + # use huggingface models method directly + if 'prepare_inputs_fn' not in generate_kwargs and hasattr(actor.model, 'prepare_inputs_for_generation'): + generate_kwargs['prepare_inputs_fn'] = actor.model.prepare_inputs_for_generation + + if 'update_model_kwargs_fn' not in generate_kwargs: + generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn diff --git a/applications/ChatGPT/examples/train_dummy.py b/applications/ChatGPT/examples/train_dummy.py index 313be2c3b..a14117ed5 100644 --- a/applications/ChatGPT/examples/train_dummy.py +++ b/applications/ChatGPT/examples/train_dummy.py @@ -3,12 +3,6 @@ from copy import deepcopy import torch from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel -from chatgpt.nn.generation_utils import ( - bloom_prepare_inputs_fn, - gpt_prepare_inputs_fn, - opt_prepare_inputs_fn, - update_model_kwargs_fn, -) from chatgpt.trainer import PPOTrainer from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from torch.optim import Adam @@ -66,36 +60,33 @@ def main(args): if args.model == 'gpt2': tokenizer = GPT2Tokenizer.from_pretrained('gpt2') tokenizer.pad_token = tokenizer.eos_token - prepare_inputs_fn = gpt_prepare_inputs_fn elif args.model == 'bloom': tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) tokenizer.pad_token = tokenizer.eos_token - prepare_inputs_fn = bloom_prepare_inputs_fn elif args.model == 'opt': tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - prepare_inputs_fn = opt_prepare_inputs_fn else: raise ValueError(f'Unsupported model "{args.model}"') # configure trainer - trainer = PPOTrainer(strategy, - actor, - critic, - reward_model, - initial_model, - actor_optim, - critic_optim, - max_epochs=args.max_epochs, - train_batch_size=args.train_batch_size, - tokenizer=preprocess_batch, - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - prepare_inputs_fn=prepare_inputs_fn, - update_model_kwargs_fn=update_model_kwargs_fn) + trainer = PPOTrainer( + strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + max_epochs=args.max_epochs, + train_batch_size=args.train_batch_size, + tokenizer=preprocess_batch, + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device()) trainer.fit(random_prompts, diff --git a/applications/ChatGPT/examples/train_prompts.py b/applications/ChatGPT/examples/train_prompts.py index 994b10fe0..cf351b91a 100644 --- a/applications/ChatGPT/examples/train_prompts.py +++ b/applications/ChatGPT/examples/train_prompts.py @@ -3,7 +3,6 @@ from copy import deepcopy import pandas as pd from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel -from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn from chatgpt.trainer import PPOTrainer from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from torch.optim import Adam @@ -70,24 +69,24 @@ def main(args): return {k: v.cuda() for k, v in batch.items()} # configure trainer - trainer = PPOTrainer(strategy, - actor, - critic, - reward_model, - initial_model, - actor_optim, - critic_optim, - max_epochs=args.max_epochs, - train_batch_size=args.train_batch_size, - tokenizer=tokenize_fn, - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - prepare_inputs_fn=gpt_prepare_inputs_fn, - update_model_kwargs_fn=update_model_kwargs_fn) + trainer = PPOTrainer( + strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + max_epochs=args.max_epochs, + train_batch_size=args.train_batch_size, + tokenizer=tokenize_fn, + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) trainer.fit(dataset, num_episodes=args.num_episodes, From 7aacfad8aff2fe3654aa3f0204e5dc6fad813ed3 Mon Sep 17 00:00:00 2001 From: "CH.Li" <32587096+lich99@users.noreply.github.com> Date: Wed, 15 Feb 2023 14:54:53 +0800 Subject: [PATCH 25/30] fix typo (#2721) --- applications/ChatGPT/benchmarks/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/ChatGPT/benchmarks/README.md b/applications/ChatGPT/benchmarks/README.md index f7212fc89..b4e28ba1d 100644 --- a/applications/ChatGPT/benchmarks/README.md +++ b/applications/ChatGPT/benchmarks/README.md @@ -37,7 +37,7 @@ We only support `torchrun` to launch now. E.g. ```shell # run GPT2-S on single-node single-GPU with min batch size -torchrun --standalone --nproc_pero_node 1 benchmark_gpt_dummy.py --model s --strategy ddp --experience_batch_size 1 --train_batch_size 1 +torchrun --standalone --nproc_per_node 1 benchmark_gpt_dummy.py --model s --strategy ddp --experience_batch_size 1 --train_batch_size 1 # run GPT2-XL on single-node 4-GPU torchrun --standalone --nproc_per_node 4 benchmark_gpt_dummy.py --model xl --strategy colossalai_zero2 # run GPT3 on 8-node 8-GPU @@ -84,7 +84,7 @@ We only support `torchrun` to launch now. E.g. ```shell # run OPT-125M with no lora (lora_rank=0) on single-node single-GPU with min batch size -torchrun --standalone --nproc_pero_node 1 benchmark_opt_lora_dummy.py --model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0 +torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py --model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0 # run OPT-350M with lora_rank=4 on single-node 4-GPU torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py --model 350m --strategy colossalai_zero2 --lora_rank 4 ``` From c5be83afbf8d64c9966d802504e4619f3c3fc4a9 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Wed, 15 Feb 2023 16:48:08 +0800 Subject: [PATCH 26/30] Update version.txt (#2727) --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index abd410582..3a4036fb4 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.2.4 +0.2.5 From 5479fdd5b86a809e2dad20b3279abd1d58816a44 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 15 Feb 2023 17:39:50 +0800 Subject: [PATCH 27/30] [doc] updated documentation version list (#2730) --- docs/versions.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/versions.json b/docs/versions.json index 49a0fab2b..6dd417a18 100644 --- a/docs/versions.json +++ b/docs/versions.json @@ -1,3 +1,3 @@ [ - "v0.2.4" + "v0.2.5" ] From 43dffdaba58ffc9f2de131ea80c91e99d5dfe756 Mon Sep 17 00:00:00 2001 From: cloudhuang Date: Wed, 15 Feb 2023 22:24:45 +0800 Subject: [PATCH 28/30] [doc] fixed a typo in GPT readme (#2736) --- examples/language/gpt/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/language/gpt/README.md b/examples/language/gpt/README.md index 3d5ce7c88..fe7b23beb 100644 --- a/examples/language/gpt/README.md +++ b/examples/language/gpt/README.md @@ -36,7 +36,7 @@ If you want to test ZeRO1 and ZeRO2 in Colossal-AI, you need to ensure Colossal- ## Dataset -For simplicity, the input data is randonly generated here. +For simplicity, the input data is randomly generated here. ## Training We provide two stable solutions. From ae86a29e2379314da3cb4abf95b5306db6156794 Mon Sep 17 00:00:00 2001 From: YH <100389977+yhna940@users.noreply.github.com> Date: Wed, 15 Feb 2023 23:27:58 +0900 Subject: [PATCH 29/30] Refact method of grad store (#2687) --- .../bookkeeping/gradient_store.py | 23 ++++++++++++++++--- .../zero/sharded_optim/low_level_optim.py | 18 +++++++++------ 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py index 8a9128a18..b166752cc 100644 --- a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py +++ b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py @@ -6,7 +6,6 @@ from .base_store import BaseStore class GradientStore(BaseStore): - def __init__(self, *args): super().__init__(*args) # bookkeeping data structures @@ -15,7 +14,7 @@ class GradientStore(BaseStore): # for backward reduction hooks self._grad_acc_objs = [] - def add_accumulate_grad_object(self, obj): + def append_accumulate_grad_object(self, obj): """ Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not be attached successfully. @@ -36,10 +35,12 @@ class GradientStore(BaseStore): :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter. :rtype: List[torch.Tensor] """ + if group_id not in self._averaged_gradients: + self._averaged_gradients[group_id] = [] return self._averaged_gradients[group_id] - def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None: + def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None: """ Append an average gradient to the list of averaged gradients of a parameter group @@ -55,6 +56,22 @@ class GradientStore(BaseStore): else: self._averaged_gradients[group_id] = [tensor] + def add_average_gradient_by_group( + self, group_id: int, tensor_idx: int, tensor: Tensor + ) -> None: + """ + Add an average gradient to the list of averaged gradients of a parameter group + + :param group_id: The index of a parameter group + :param tensor_idx: The index of a tensor in the list of averaged gradients + :param tensor: A :class:`torch.Tensor` object + :type group_id: int + :type tensor_idx: int + :type tensor: torch.Tensor + + """ + self._averaged_gradients[group_id][tensor_idx].add_(tensor) + def reset_average_gradients_by_group(self, group_id: int) -> None: """ Reset the bookkeeping data structure for averaged gradients to an empty list diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/sharded_optim/low_level_optim.py index 89f5f9fad..f5e03ce28 100644 --- a/colossalai/zero/sharded_optim/low_level_optim.py +++ b/colossalai/zero/sharded_optim/low_level_optim.py @@ -550,20 +550,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): reduction_states[tensor] = False # accumulate gradient - avg_gradients = self._grad_store._averaged_gradients for group_id in range(self.num_param_groups): param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id) - if group_id not in avg_gradients: - avg_gradients[group_id] = [] + avg_gradients_group = self._grad_store.get_averaged_gradients_by_group( + group_id + ) param_idx = 0 for param in param_group: if param.grad is not None: - if len(avg_gradients[group_id]) == param_idx: - avg_gradients[group_id].append(param.grad) + if len(avg_gradients_group) == param_idx: + self._grad_store.append_average_gradient_by_group( + group_id, param.grad + ) else: - avg_gradients[group_id][param_idx].add_(param.grad) + self._grad_store.add_average_gradient_by_group( + group_id, param_idx, param.grad + ) param_idx += 1 # the gradients needed are stored in the avg_gradients buffer @@ -590,4 +594,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # only need to reduce the gradients # left in the communication bucket for reduce_rank in range(self._world_size): - self._run_reduction(reduce_rank) + self._run_reduction(reduce_rank) \ No newline at end of file From 1dc003c1698730234ca6a10248d1d3b800fc9ad9 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Wed, 15 Feb 2023 22:28:28 +0800 Subject: [PATCH 30/30] [autoparallel] distinguish different parallel strategies (#2699) --- .../node_handler/linear_handler.py | 5 +- .../strategy/matmul_strategy_generator.py | 58 ++++++-- .../test_gpt/test_runtime_with_gpt_modules.py | 2 +- .../test_permute_and_transpose_handler.py | 138 +++++++++--------- .../test_node_handler/test_softmax_handler.py | 88 +++++------ .../test_node_handler/test_split_handler.py | 88 +++++------ .../test_node_handler/test_view_handler.py | 95 ++++++------ 7 files changed, 255 insertions(+), 219 deletions(-) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py index 37ff3c3ab..59091dab5 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py @@ -152,7 +152,10 @@ class LinearModuleHandler(MetaInfoModuleHandler): op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + LinearProjectionStrategyGenerator(op_data_mapping, + self.device_mesh, + linear_projection_type='linear', + solver_perference=self.solver_perference)) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index fa2246f95..5d70e131d 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -3,6 +3,7 @@ from ast import arg from functools import reduce from typing import List +from colossalai.auto_parallel.tensor_shard.options import SolverPerference from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( CommType, MemoryCost, @@ -209,9 +210,14 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator): class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): - def __init__(self, operation_data_mapping, device_mesh, linear_projection_type='linear'): + def __init__(self, + operation_data_mapping, + device_mesh, + linear_projection_type='linear', + solver_perference=SolverPerference.STANDARD): super().__init__(operation_data_mapping, device_mesh) self.linear_projection_type = linear_projection_type + self.solver_perference = solver_perference def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: # C = AB @@ -231,16 +237,22 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): total=fwd_compute_cost + bwd_compute_cost) strategy.compute_cost = compute_cost - def collate_strategies(self) -> List[ShardingStrategy]: + def dp_strategies(self) -> List[ShardingStrategy]: strategies = [] - # SS = SR x RS - strategies.append(self.split_lhs_space_rhs_space(0, 1)) - strategies.append(self.split_lhs_space_rhs_space(1, 0)) + # S01R = S01R x RR + strategies.append(self.split_lhs_1st_dim_1d(0, 1)) - # SR = SS x SR - strategies.append(self.split_lhs_space_both_contract(0, 1)) - strategies.append(self.split_lhs_space_both_contract(1, 0)) + return strategies + + def tp_strategies(self) -> List[ShardingStrategy]: + strategies = [] + + # RR = RS01 x S01R + strategies.append(self.split_lhs_2nd_dim_1d(0, 1)) + + # RS01 = RR x RS01 + strategies.append(self.split_rhs_2nd_dim_1d(0, 1)) # RS = RS x SS strategies.append(self.split_rhs_space_both_contract(0, 1)) @@ -254,20 +266,38 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): strategies.append(self.split_rhs_space_only(0)) strategies.append(self.split_rhs_space_only(1)) - # S01R = S01R x RR - strategies.append(self.split_lhs_1st_dim_1d(0, 1)) + return strategies - # RR = RS01 x S01R - strategies.append(self.split_lhs_2nd_dim_1d(0, 1)) + def mix_strategies(self) -> List[ShardingStrategy]: + strategies = [] - # RS01 = RR x RS01 - strategies.append(self.split_rhs_2nd_dim_1d(0, 1)) + # SS = SR x RS + strategies.append(self.split_lhs_space_rhs_space(0, 1)) + strategies.append(self.split_lhs_space_rhs_space(1, 0)) + + # SR = SS x SR + strategies.append(self.split_lhs_space_both_contract(0, 1)) + strategies.append(self.split_lhs_space_both_contract(1, 0)) # RR = RR x RR strategies.append(self.non_split()) return strategies + def collate_strategies(self) -> List[ShardingStrategy]: + strategies = [] + + if self.solver_perference == SolverPerference.STANDARD: + strategies.extend(self.dp_strategies()) + strategies.extend(self.tp_strategies()) + strategies.extend(self.mix_strategies()) + elif self.solver_perference == SolverPerference.DP: + strategies.extend(self.dp_strategies()) + elif self.solver_perference == SolverPerference.TP: + strategies.extend(self.tp_strategies()) + + return strategies + @ignore_sharding_exception def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): # handle case SS = SR x RS diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py index 753ecff53..ebeef9870 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py @@ -117,7 +117,7 @@ def check_attention_layer(rank, model_cls, world_size, port): gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() - strategies_constructor = build_strategy_constructor(graph, device_mesh) + strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard') solution = solve_solution(gm, strategies_constructor, memory_budget=-1) gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor) gm = ModuleWrapper(gm, *sharding_spec_dicts) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py index b12db1332..af03481d8 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py @@ -243,79 +243,79 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, if model_cls.__name__ == 'LinearReshapeModel': if reshape_dims == ((0, 2, 1, 3), (1, 2)): - assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list - assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list - assert '[R, R, S0, S1] -> [R, S0, R, S1]_2' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list - assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list - assert '[R, R, S1, S0] -> [R, S1, R, S0]_5' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list - assert '[R, R, S0, R] -> [R, S0, R, R]_8' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list - assert '[R, R, S1, R] -> [R, S1, R, R]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list - assert '[R, R, S01, R] -> [R, S01, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S0, R, S1]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S1, R, S0]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, S0, R, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, S1, R, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, S01, R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list if reshape_dims == (2, 0, 1, 3): - assert '[S0, R, R, S1] -> [R, S0, R, S1]_0' in strategy_name_list - assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list - assert '[R, R, S0, S1] -> [S0, R, R, S1]_2' in strategy_name_list - assert '[S1, R, R, S0] -> [R, S1, R, S0]_3' in strategy_name_list - assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list - assert '[R, R, S1, S0] -> [S1, R, R, S0]_5' in strategy_name_list - assert '[S0, R, R, R] -> [R, S0, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list - assert '[R, R, S0, R] -> [S0, R, R, R]_8' in strategy_name_list - assert '[S1, R, R, R] -> [R, S1, R, R]_9' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list - assert '[R, R, S1, R] -> [S1, R, R, R]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list - assert '[S01, R, R, R] -> [R, S01, R, R]_18' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list - assert '[R, R, S01, R] -> [S01, R, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + assert '[S0, R, R, S1] -> [R, S0, R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [S0, R, R, S1]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [R, S1, R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [S1, R, R, S0]_16' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [S0, R, R, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [S1, R, R, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[S01, R, R, R] -> [R, S01, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [S01, R, R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list if reshape_dims == (1, 3): - assert '[S0, R, R, S1] -> [S0, S1, R, R]_0' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S1, R, S0]_1' in strategy_name_list - assert '[R, R, S0, S1] -> [R, S1, S0, R]_2' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, S0, R, R]_3' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S0, R, S1]_4' in strategy_name_list - assert '[R, R, S1, S0] -> [R, S0, S1, R]_5' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1, R, R]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0, R, R]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1, R, R]_17' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, R, S01]_19' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, S01, R, R]_22' in strategy_name_list + assert '[S0, R, R, S1] -> [S0, S1, R, R]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S1, R, S0]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S1, S0, R]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, S0, R, R]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S0, R, S1]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S0, S1, R]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1, R, R]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0, R, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0, R, R]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1, R, R]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, R, S01]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, S01, R, R]_4' in strategy_name_list @run_on_environment_flag(name='AUTO_PARALLEL') diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py index b5e8e3277..c43ee292b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py @@ -117,54 +117,54 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port): strategy_name_list = [strategy.name for strategy in split_strategies_vector] if softmax_dim == 0: - assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list if softmax_dim == 1: - assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list @run_on_environment_flag(name='AUTO_PARALLEL') diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py index 813651869..044aef19d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py @@ -198,54 +198,54 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port if model_cls.__name__ == 'LinearSplitModel': if split_dim == 0: - assert '[R, R, R, S1]_0' in strategy_name_list - assert '[R, S0, R, S1]_1' in strategy_name_list - assert '[R, R, S0, S1]_2' in strategy_name_list - assert '[R, R, R, S0]_3' in strategy_name_list - assert '[R, S1, R, S0]_4' in strategy_name_list - assert '[R, R, S1, S0]_5' in strategy_name_list - assert '[R, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R]_7' in strategy_name_list - assert '[R, R, S0, R]_8' in strategy_name_list - assert '[R, R, R, R]_9' in strategy_name_list - assert '[R, S1, R, R]_10' in strategy_name_list - assert '[R, R, S1, R]_11' in strategy_name_list - assert '[R, R, R, S1]_12' in strategy_name_list - assert '[R, R, R, S0]_13' in strategy_name_list - assert '[R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, S1]_17' in strategy_name_list - assert '[R, R, R, R]_18' in strategy_name_list - assert '[R, S01, R, R]_19' in strategy_name_list - assert '[R, R, S01, R]_20' in strategy_name_list - assert '[R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01]_22' in strategy_name_list + assert '[R, R, R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1]_12' in strategy_name_list + assert '[R, R, S0, S1]_13' in strategy_name_list + assert '[R, R, R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0]_16' in strategy_name_list + assert '[R, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R]_18' in strategy_name_list + assert '[R, R, S0, R]_19' in strategy_name_list + assert '[R, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R]_21' in strategy_name_list + assert '[R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1]_5' in strategy_name_list + assert '[R, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R]_1' in strategy_name_list + assert '[R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01]_4' in strategy_name_list if split_dim == 1: - assert '[S0, R, R, S1]_0' in strategy_name_list - assert '[R, R, R, S1]_1' in strategy_name_list - assert '[R, R, S0, S1]_2' in strategy_name_list - assert '[S1, R, R, S0]_3' in strategy_name_list - assert '[R, R, R, S0]_4' in strategy_name_list - assert '[R, R, S1, S0]_5' in strategy_name_list - assert '[S0, R, R, R]_6' in strategy_name_list - assert '[R, R, R, R]_7' in strategy_name_list - assert '[R, R, S0, R]_8' in strategy_name_list - assert '[S1, R, R, R]_9' in strategy_name_list - assert '[R, R, R, R]_10' in strategy_name_list - assert '[R, R, S1, R]_11' in strategy_name_list + assert '[S0, R, R, S1]_11' in strategy_name_list assert '[R, R, R, S1]_12' in strategy_name_list - assert '[R, R, R, S0]_13' in strategy_name_list - assert '[R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, S1]_17' in strategy_name_list - assert '[S01, R, R, R]_18' in strategy_name_list - assert '[R, R, R, R]_19' in strategy_name_list - assert '[R, R, S01, R]_20' in strategy_name_list + assert '[R, R, S0, S1]_13' in strategy_name_list + assert '[S1, R, R, S0]_14' in strategy_name_list + assert '[R, R, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0]_16' in strategy_name_list + assert '[S0, R, R, R]_17' in strategy_name_list + assert '[R, R, R, R]_18' in strategy_name_list + assert '[R, R, S0, R]_19' in strategy_name_list + assert '[S1, R, R, R]_20' in strategy_name_list assert '[R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01]_22' in strategy_name_list + assert '[R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1]_5' in strategy_name_list + assert '[S01, R, R, R]_0' in strategy_name_list + assert '[R, R, R, R]_1' in strategy_name_list + assert '[R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01]_4' in strategy_name_list @run_on_environment_flag(name='AUTO_PARALLEL') diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py index d07d2f76c..8a96ac0d6 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -196,54 +196,57 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): if model_cls.__name__ == 'LinearViewModel': if tgt_shape == (32, 4, 64, 16, 4): - assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_0' in strategy_name_list - assert '[R, S0, R, S1] -> FULLY REPLICATED_1' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_2' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_3' in strategy_name_list - assert '[R, S1, R, S0] -> FULLY REPLICATED_4' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_5' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R, R]_8' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R]_9' in strategy_name_list - assert '[R, S1, R, R] -> FULLY REPLICATED_10' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R, R]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1, R]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0, R]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1, R]_17' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R]_18' in strategy_name_list - assert '[R, S01, R, R] -> FULLY REPLICATED_19' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01, R]_22' in strategy_name_list + for strategy in strategy_name_list: + print(strategy) + # print(strategy_name_list) + assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_11' in strategy_name_list + assert '[R, S0, R, S1] -> FULLY REPLICATED_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_14' in strategy_name_list + assert '[R, S1, R, S0] -> FULLY REPLICATED_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> FULLY REPLICATED_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> FULLY REPLICATED_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1, R]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0, R]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1, R]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> FULLY REPLICATED_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01, R]_4' in strategy_name_list if tgt_shape == (8, 4, 4, 64, 16, 4): - assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_0' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_1' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_2' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_3' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_4' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_5' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_8' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_9' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_10' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_17' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_18' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_19' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_22' in strategy_name_list + assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_4' in strategy_name_list @run_on_environment_flag(name='AUTO_PARALLEL')