diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 711439955..5137494ad 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -338,7 +338,7 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: Returns: torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo. """ - device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') interp = MetaInfoProp(gm.to(device)) if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor diff --git a/examples/tutorial/auto_parallel/README.md b/examples/tutorial/auto_parallel/README.md index bed488022..e93a8288b 100644 --- a/examples/tutorial/auto_parallel/README.md +++ b/examples/tutorial/auto_parallel/README.md @@ -15,3 +15,82 @@ export DATA=/path/to/data ```bash colossalai run --nproc_per_node 4 auto_parallel_demo.py ``` + +## Auto Checkpoint Benchmarking + +We provide three demos for you to test the performance of auto checkpoint. The tests `demo_resnet50.py` and `demo_gpt2_medium.py` show the ability of the solver to search for a checkpoint strategy that fits within the given memory budget. 
+ +The usage of the above two tests is as follows: +```bash +python demo_resnet50.py --help +usage: ResNet50 Auto Activation Benchmark [-h] [--batch_size BATCH_SIZE] [--num_steps NUM_STEPS] [--sample_points SAMPLE_POINTS] [--free_memory FREE_MEMORY] + [--start_factor START_FACTOR] + +optional arguments: + -h, --help show this help message and exit + --batch_size BATCH_SIZE + batch size for benchmark, default 128 + --num_steps NUM_STEPS + number of test steps for benchmark, default 5 + --sample_points SAMPLE_POINTS + number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15 + --free_memory FREE_MEMORY + maximum memory budget in MB for benchmark, default 11000 MB + --start_factor START_FACTOR + start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 4 + +# run with default settings +python demo_resnet50.py + +python demo_gpt2_medium.py --help +usage: GPT2 medium Auto Activation Benchmark [-h] [--batch_size BATCH_SIZE] [--num_steps NUM_STEPS] [--sample_points SAMPLE_POINTS] [--free_memory FREE_MEMORY] + [--start_factor START_FACTOR] + +optional arguments: + -h, --help show this help message and exit + --batch_size BATCH_SIZE + batch size for benchmark, default 8 + --num_steps NUM_STEPS + number of test steps for benchmark, default 5 + --sample_points SAMPLE_POINTS + number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15 + --free_memory FREE_MEMORY + maximum memory budget in MB for benchmark, default 56000 MB + --start_factor START_FACTOR + start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 10 + +# run with default settings +python demo_gpt2_medium.py +``` + +Here are some results for your reference: + +### ResNet 50 +![](./imgs/resnet50_benchmark.png) + +### GPT2 Medium +![](./imgs/gpt2_benchmark.png) + +We also provide the demo `demo_resnet152.py` to demonstrate 
the benefit of auto activation with large batch sizes. The usage is listed as follows: +```bash +python demo_resnet152.py --help +usage: ResNet152 Auto Activation Through Put Benchmark [-h] [--num_steps NUM_STEPS] + +optional arguments: + -h, --help show this help message and exit + --num_steps NUM_STEPS + number of test steps for benchmark, default 5 + +# run with default settings +python demo_resnet152.py +``` + +Here are some results on our end for your reference: +```bash +===============test summary================ +batch_size: 512, peak memory: 73314.392 MB, through put: 254.286 images/s +batch_size: 1024, peak memory: 73316.216 MB, through put: 397.608 images/s +batch_size: 2048, peak memory: 72927.837 MB, through put: 277.429 images/s +``` + +The above tests will output the test summary and a plot of the benchmarking results. diff --git a/examples/tutorial/auto_parallel/auto_ckpt_demo.ipynb b/examples/tutorial/auto_parallel/auto_ckpt_demo.ipynb deleted file mode 100644 index cacf5d5f3..000000000 --- a/examples/tutorial/auto_parallel/auto_ckpt_demo.ipynb +++ /dev/null @@ -1,878 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/lcsjy/.conda/envs/autoparallel/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "data": { - "text/html": [ - "
[11/10/22 18:04:14] INFO     colossalai - torch.distributed.distributed_c10d - INFO: Added key:                    \n",
-       "                             store_based_barrier_key:1 to store for rank: 0                                        \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m[11/10/22 18:04:14]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Added key: \n", - "\u001b[2;36m \u001b[0m store_based_barrier_key:\u001b[1;36m1\u001b[0m to store for rank: \u001b[1;36m0\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based \n",
-       "                             barrier for key:store_based_barrier_key:1 with 1 nodes.                               \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Rank \u001b[1;36m0\u001b[0m: Completed store-based \n", - "\u001b[2;36m \u001b[0m barrier for key:store_based_barrier_key:\u001b[1;36m1\u001b[0m with \u001b[1;36m1\u001b[0m nodes. \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Added key:                    \n",
-       "                             store_based_barrier_key:2 to store for rank: 0                                        \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Added key: \n", - "\u001b[2;36m \u001b[0m store_based_barrier_key:\u001b[1;36m2\u001b[0m to store for rank: \u001b[1;36m0\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based \n",
-       "                             barrier for key:store_based_barrier_key:2 with 1 nodes.                               \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Rank \u001b[1;36m0\u001b[0m: Completed store-based \n", - "\u001b[2;36m \u001b[0m barrier for key:store_based_barrier_key:\u001b[1;36m2\u001b[0m with \u001b[1;36m1\u001b[0m nodes. \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Added key:                    \n",
-       "                             store_based_barrier_key:3 to store for rank: 0                                        \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Added key: \n", - "\u001b[2;36m \u001b[0m store_based_barrier_key:\u001b[1;36m3\u001b[0m to store for rank: \u001b[1;36m0\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based \n",
-       "                             barrier for key:store_based_barrier_key:3 with 1 nodes.                               \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Rank \u001b[1;36m0\u001b[0m: Completed store-based \n", - "\u001b[2;36m \u001b[0m barrier for key:store_based_barrier_key:\u001b[1;36m3\u001b[0m with \u001b[1;36m1\u001b[0m nodes. \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Added key:                    \n",
-       "                             store_based_barrier_key:4 to store for rank: 0                                        \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Added key: \n", - "\u001b[2;36m \u001b[0m store_based_barrier_key:\u001b[1;36m4\u001b[0m to store for rank: \u001b[1;36m0\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based \n",
-       "                             barrier for key:store_based_barrier_key:4 with 1 nodes.                               \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Rank \u001b[1;36m0\u001b[0m: Completed store-based \n", - "\u001b[2;36m \u001b[0m barrier for key:store_based_barrier_key:\u001b[1;36m4\u001b[0m with \u001b[1;36m1\u001b[0m nodes. \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Added key:                    \n",
-       "                             store_based_barrier_key:5 to store for rank: 0                                        \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Added key: \n", - "\u001b[2;36m \u001b[0m store_based_barrier_key:\u001b[1;36m5\u001b[0m to store for rank: \u001b[1;36m0\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based \n",
-       "                             barrier for key:store_based_barrier_key:5 with 1 nodes.                               \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Rank \u001b[1;36m0\u001b[0m: Completed store-based \n", - "\u001b[2;36m \u001b[0m barrier for key:store_based_barrier_key:\u001b[1;36m5\u001b[0m with \u001b[1;36m1\u001b[0m nodes. \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Added key:                    \n",
-       "                             store_based_barrier_key:6 to store for rank: 0                                        \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Added key: \n", - "\u001b[2;36m \u001b[0m store_based_barrier_key:\u001b[1;36m6\u001b[0m to store for rank: \u001b[1;36m0\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based \n",
-       "                             barrier for key:store_based_barrier_key:6 with 1 nodes.                               \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Rank \u001b[1;36m0\u001b[0m: Completed store-based \n", - "\u001b[2;36m \u001b[0m barrier for key:store_based_barrier_key:\u001b[1;36m6\u001b[0m with \u001b[1;36m1\u001b[0m nodes. \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Added key:                    \n",
-       "                             store_based_barrier_key:7 to store for rank: 0                                        \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Added key: \n", - "\u001b[2;36m \u001b[0m store_based_barrier_key:\u001b[1;36m7\u001b[0m to store for rank: \u001b[1;36m0\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based \n",
-       "                             barrier for key:store_based_barrier_key:7 with 1 nodes.                               \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Rank \u001b[1;36m0\u001b[0m: Completed store-based \n", - "\u001b[2;36m \u001b[0m barrier for key:store_based_barrier_key:\u001b[1;36m7\u001b[0m with \u001b[1;36m1\u001b[0m nodes. \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Added key:                    \n",
-       "                             store_based_barrier_key:8 to store for rank: 0                                        \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Added key: \n", - "\u001b[2;36m \u001b[0m store_based_barrier_key:\u001b[1;36m8\u001b[0m to store for rank: \u001b[1;36m0\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based \n",
-       "                             barrier for key:store_based_barrier_key:8 with 1 nodes.                               \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - torch.distributed.distributed_c10d - INFO: Rank \u001b[1;36m0\u001b[0m: Completed store-based \n", - "\u001b[2;36m \u001b[0m barrier for key:store_based_barrier_key:\u001b[1;36m8\u001b[0m with \u001b[1;36m1\u001b[0m nodes. \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - colossalai - INFO:                                                       \n",
-       "                             /home/lcsjy/ColossalAI/colossalai/context/parallel_context.py:521 set_device          \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - colossalai - INFO: \n", - "\u001b[2;36m \u001b[0m \u001b[35m/home/lcsjy/ColossalAI/colossalai/context/\u001b[0m\u001b[95mparallel_context.py\u001b[0m:\u001b[1;36m521\u001b[0m set_device \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - colossalai - INFO: process rank 0 is bound to device 0                   \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - colossalai - INFO: process rank \u001b[1;36m0\u001b[0m is bound to device \u001b[1;36m0\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - colossalai - INFO:                                                       \n",
-       "                             /home/lcsjy/ColossalAI/colossalai/context/parallel_context.py:557 set_seed            \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - colossalai - INFO: \n", - "\u001b[2;36m \u001b[0m \u001b[35m/home/lcsjy/ColossalAI/colossalai/context/\u001b[0m\u001b[95mparallel_context.py\u001b[0m:\u001b[1;36m557\u001b[0m set_seed \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - colossalai - INFO: initialized seed on rank 0, numpy: 1024, python       \n",
-       "                             random: 1024, ParallelMode.DATA: 1024, ParallelMode.TENSOR: 1024,the default parallel \n",
-       "                             seed is ParallelMode.DATA.                                                            \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - colossalai - INFO: initialized seed on rank \u001b[1;36m0\u001b[0m, numpy: \u001b[1;36m1024\u001b[0m, python \n", - "\u001b[2;36m \u001b[0m random: \u001b[1;36m1024\u001b[0m, ParallelMode.DATA: \u001b[1;36m1024\u001b[0m, ParallelMode.TENSOR: \u001b[1;36m1024\u001b[0m,the default parallel \n", - "\u001b[2;36m \u001b[0m seed is ParallelMode.DATA. \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - colossalai - INFO: /home/lcsjy/ColossalAI/colossalai/initialize.py:117   \n",
-       "                             launch                                                                                \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - colossalai - INFO: \u001b[35m/home/lcsjy/ColossalAI/colossalai/\u001b[0m\u001b[95minitialize.py\u001b[0m:\u001b[1;36m117\u001b[0m \n", - "\u001b[2;36m \u001b[0m launch \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    INFO     colossalai - colossalai - INFO: Distributed environment is initialized, data parallel \n",
-       "                             size: 1, pipeline parallel size: 1, tensor parallel size: 1                           \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - colossalai - INFO: Distributed environment is initialized, data parallel \n", - "\u001b[2;36m \u001b[0m size: \u001b[1;36m1\u001b[0m, pipeline parallel size: \u001b[1;36m1\u001b[0m, tensor parallel size: \u001b[1;36m1\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import time\n", - "import torchvision.models as tm\n", - "import torch\n", - "import colossalai\n", - "from colossalai.fx import symbolic_trace, metainfo_trace\n", - "from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor\n", - "from functools import partial\n", - "from colossalai.utils import free_port\n", - "\n", - "from bench_utils import bench, bench_rotor\n", - "import matplotlib.pyplot as plt\n", - "\n", - "colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### ResNet152 with batch size = 512 fails" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(78990.4404296875, inf)" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def data_gen(batch_size, shape, device='cuda'):\n", - " data = torch.empty(batch_size, *shape, device=device)\n", - " label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)\n", - " return {'x': data}, label\n", - "\n", - "model = tm.resnet152()\n", - "gm = symbolic_trace(model)\n", - "gm = metainfo_trace(gm, torch.empty(512, 3, 224, 224, device='meta'))\n", - "bench(gm, torch.nn.CrossEntropyLoss(), partial(data_gen, batch_size=512, shape=(3, 224, 224)), num_steps=5)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### ResNet152 with batch size = 2048 succeeds " - ] - }, - { - "cell_type": 
"code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(74495.8486328125, 5634.262561798096)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def data_gen(batch_size, shape, device='cuda'):\n", - " data = torch.empty(batch_size, *shape, device=device)\n", - " label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)\n", - " return {'x': data}, label\n", - "\n", - "model = tm.resnet152()\n", - "gm = symbolic_trace(model)\n", - "gm = metainfo_trace(gm, torch.empty(2048, 3, 224, 224, device='meta'))\n", - "solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0] * 0.95)\n", - "gm.graph = solver.solve()\n", - "bench(gm, torch.nn.CrossEntropyLoss(), partial(data_gen, batch_size=2048, shape=(3, 224, 224)), num_steps=5)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Benchmarking on ResNet18" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
[11/10/22 18:04:20] WARNING  colossalai - colossalai - WARNING:                                                    \n",
-       "                             /home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py:82    \n",
-       "                             solve                                                                                 \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m[11/10/22 18:04:20]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: \n", - "\u001b[2;36m \u001b[0m \u001b[35m/home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/\u001b[0m\u001b[95mckpt_solver_rotor.py\u001b[0m:\u001b[1;36m82\u001b[0m \n", - "\u001b[2;36m \u001b[0m solve \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this     \n",
-       "                             chain from index 0 to 14 with memory 500                                              \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this \n", - "\u001b[2;36m \u001b[0m chain from index \u001b[1;36m0\u001b[0m to \u001b[1;36m14\u001b[0m with memory \u001b[1;36m500\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING:                                                    \n",
-       "                             /home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py:82    \n",
-       "                             solve                                                                                 \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: \n", - "\u001b[2;36m \u001b[0m \u001b[35m/home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/\u001b[0m\u001b[95mckpt_solver_rotor.py\u001b[0m:\u001b[1;36m82\u001b[0m \n", - "\u001b[2;36m \u001b[0m solve \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this     \n",
-       "                             chain from index 0 to 14 with memory 500                                              \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this \n", - "\u001b[2;36m \u001b[0m chain from index \u001b[1;36m0\u001b[0m to \u001b[1;36m14\u001b[0m with memory \u001b[1;36m500\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING:                                                    \n",
-       "                             /home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py:82    \n",
-       "                             solve                                                                                 \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: \n", - "\u001b[2;36m \u001b[0m \u001b[35m/home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/\u001b[0m\u001b[95mckpt_solver_rotor.py\u001b[0m:\u001b[1;36m82\u001b[0m \n", - "\u001b[2;36m \u001b[0m solve \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this     \n",
-       "                             chain from index 0 to 14 with memory 500                                              \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this \n", - "\u001b[2;36m \u001b[0m chain from index \u001b[1;36m0\u001b[0m to \u001b[1;36m14\u001b[0m with memory \u001b[1;36m500\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING:                                                    \n",
-       "                             /home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py:82    \n",
-       "                             solve                                                                                 \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: \n", - "\u001b[2;36m \u001b[0m \u001b[35m/home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/\u001b[0m\u001b[95mckpt_solver_rotor.py\u001b[0m:\u001b[1;36m82\u001b[0m \n", - "\u001b[2;36m \u001b[0m solve \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this     \n",
-       "                             chain from index 0 to 14 with memory 500                                              \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this \n", - "\u001b[2;36m \u001b[0m chain from index \u001b[1;36m0\u001b[0m to \u001b[1;36m14\u001b[0m with memory \u001b[1;36m500\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
[11/10/22 18:04:21] WARNING  colossalai - colossalai - WARNING:                                                    \n",
-       "                             /home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py:82    \n",
-       "                             solve                                                                                 \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m[11/10/22 18:04:21]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: \n", - "\u001b[2;36m \u001b[0m \u001b[35m/home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/\u001b[0m\u001b[95mckpt_solver_rotor.py\u001b[0m:\u001b[1;36m82\u001b[0m \n", - "\u001b[2;36m \u001b[0m solve \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this     \n",
-       "                             chain from index 0 to 14 with memory 500                                              \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this \n", - "\u001b[2;36m \u001b[0m chain from index \u001b[1;36m0\u001b[0m to \u001b[1;36m14\u001b[0m with memory \u001b[1;36m500\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING:                                                    \n",
-       "                             /home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py:82    \n",
-       "                             solve                                                                                 \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: \n", - "\u001b[2;36m \u001b[0m \u001b[35m/home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/\u001b[0m\u001b[95mckpt_solver_rotor.py\u001b[0m:\u001b[1;36m82\u001b[0m \n", - "\u001b[2;36m \u001b[0m solve \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this     \n",
-       "                             chain from index 0 to 14 with memory 500                                              \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this \n", - "\u001b[2;36m \u001b[0m chain from index \u001b[1;36m0\u001b[0m to \u001b[1;36m14\u001b[0m with memory \u001b[1;36m500\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING:                                                    \n",
-       "                             /home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py:82    \n",
-       "                             solve                                                                                 \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: \n", - "\u001b[2;36m \u001b[0m \u001b[35m/home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/\u001b[0m\u001b[95mckpt_solver_rotor.py\u001b[0m:\u001b[1;36m82\u001b[0m \n", - "\u001b[2;36m \u001b[0m solve \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this     \n",
-       "                             chain from index 0 to 14 with memory 500                                              \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this \n", - "\u001b[2;36m \u001b[0m chain from index \u001b[1;36m0\u001b[0m to \u001b[1;36m14\u001b[0m with memory \u001b[1;36m500\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
[11/10/22 18:04:22] WARNING  colossalai - colossalai - WARNING:                                                    \n",
-       "                             /home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py:82    \n",
-       "                             solve                                                                                 \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m[11/10/22 18:04:22]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: \n", - "\u001b[2;36m \u001b[0m \u001b[35m/home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/\u001b[0m\u001b[95mckpt_solver_rotor.py\u001b[0m:\u001b[1;36m82\u001b[0m \n", - "\u001b[2;36m \u001b[0m solve \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this     \n",
-       "                             chain from index 0 to 14 with memory 500                                              \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this \n", - "\u001b[2;36m \u001b[0m chain from index \u001b[1;36m0\u001b[0m to \u001b[1;36m14\u001b[0m with memory \u001b[1;36m500\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING:                                                    \n",
-       "                             /home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py:82    \n",
-       "                             solve                                                                                 \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: \n", - "\u001b[2;36m \u001b[0m \u001b[35m/home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/\u001b[0m\u001b[95mckpt_solver_rotor.py\u001b[0m:\u001b[1;36m82\u001b[0m \n", - "\u001b[2;36m \u001b[0m solve \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this     \n",
-       "                             chain from index 0 to 14 with memory 500                                              \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this \n", - "\u001b[2;36m \u001b[0m chain from index \u001b[1;36m0\u001b[0m to \u001b[1;36m14\u001b[0m with memory \u001b[1;36m500\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
[11/10/22 18:04:23] WARNING  colossalai - colossalai - WARNING:                                                    \n",
-       "                             /home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py:82    \n",
-       "                             solve                                                                                 \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m[11/10/22 18:04:23]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: \n", - "\u001b[2;36m \u001b[0m \u001b[35m/home/lcsjy/ColossalAI/colossalai/auto_parallel/checkpoint/\u001b[0m\u001b[95mckpt_solver_rotor.py\u001b[0m:\u001b[1;36m82\u001b[0m \n", - "\u001b[2;36m \u001b[0m solve \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                    WARNING  colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this     \n",
-       "                             chain from index 0 to 14 with memory 500                                              \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m colossalai - colossalai - WARNING: Checkpoint solver failed: Can not process this \n", - "\u001b[2;36m \u001b[0m chain from index \u001b[1;36m0\u001b[0m to \u001b[1;36m14\u001b[0m with memory \u001b[1;36m500\u001b[0m \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "def data_gen(batch_size, shape, device='cuda'):\n", - " data = torch.empty(batch_size, *shape, device=device)\n", - " label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)\n", - " return (data, ), label\n", - "\n", - "model = tm.resnet18()\n", - "gm = symbolic_trace(model)\n", - "gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device='meta'))\n", - "peak_hist, step_hist = bench_rotor(gm, torch.nn.CrossEntropyLoss(), partial(data_gen, batch_size=128, shape=(3, 224, 224)), num_steps=5, sample_points=20, free_memory=2700 * 1024**2)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "plt.figure(figsize=(8, 8))\n", - "plt.plot(peak_hist, step_hist)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[540.0,\n", - " 653.6842105263158,\n", - " 767.3684210526316,\n", - " 881.0526315789474,\n", - " 994.7368421052631,\n", - " 1108.421052631579,\n", - " 1222.1052631578948,\n", - " 1335.7894736842104,\n", - " 1449.4736842105262,\n", - " 1563.157894736842,\n", - " 26711.86572265625,\n", - " 26711.86572265625,\n", - " 26711.86572265625,\n", - " 26711.86572265625,\n", - " 26711.86572265625,\n", - " 26711.86572265625,\n", - " 26711.86572265625,\n", - " 26711.86572265625,\n", - " 26711.86572265625,\n", - " 26711.86572265625]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "peak_hist" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.10.6 ('autoparallel': conda)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "cc0ad6865167fb9a52c12f0fd0c8203c9a7690797bfee612a871d56b9d2024ce" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/tutorial/auto_parallel/bench_utils.py b/examples/tutorial/auto_parallel/bench_utils.py index 365e07e21..d9d656b85 100644 --- a/examples/tutorial/auto_parallel/bench_utils.py +++ b/examples/tutorial/auto_parallel/bench_utils.py @@ -1,16 +1,33 @@ import time +from copy import deepcopy from functools import partial from typing import Callable, Tuple import numpy as np import torch +import torch.nn as nn import torchvision.models as tm +from 
transformers import GPT2Config, GPT2LMHeadModel from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor from colossalai.fx import metainfo_trace -def bench(gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5): +def bench(gm: torch.fx.GraphModule, + criterion: torch.nn.Module, + data_gen: Callable, + num_steps: int = 5) -> Tuple[int, int]: + """Benchmarking a given graph module + + Args: + gm (torch.fx.GraphModule): The graph module to benchmark. + criterion (torch.nn.Module): Loss function. + data_gen (Callable): Data generator. + num_steps (int, optional): Number of test steps. Defaults to 5. + + Returns: + Tuple[int, int]: peak memory in MB and step time in MS. + """ gm.train() gm.cuda() step_time = float('inf') @@ -39,7 +56,8 @@ def bench(gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callab del args, label, output, loss gm.to("cpu") torch.cuda.empty_cache() - return (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2, step_time * 1.0e3 + peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2 + return peak_mem, step_time * 1.0e3 def bench_rotor(gm: torch.fx.GraphModule, @@ -47,19 +65,92 @@ def bench_rotor(gm: torch.fx.GraphModule, data_gen: Callable, num_steps: int = 5, sample_points: int = 20, - free_memory: int = torch.cuda.mem_get_info()[0]): + free_memory: int = torch.cuda.mem_get_info()[0], + start_factor: int = 4) -> Tuple[np.array, list, list]: + """Auto Checkpoint Rotor Algorithm benchmarking + Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data. + + Args: + gm (torch.fx.GraphModule): The graph module to benchmark. + criterion (torch.nn.Module): Loss function. + data_gen (Callable): Data generator. + num_steps (int, optional): Number of test steps. Defaults to 5. + sample_points (int, optional): Number of sample points. Defaults to 20. + free_memory (int, optional): Max memory budget in Byte. 
Defaults to torch.cuda.mem_get_info()[0]. + start_factor (int, optional): Start memory budget factor for benchmark, the start memory budget + will be free_memory / start_factor. Defaults to 4. + + Returns: + Tuple[np.array, list, list]: return budgets vector (MB), peak memory vector (MB), step time vector (MS). + """ peak_hist, step_hist = [], [] - for budget in np.linspace(free_memory // 5, free_memory, sample_points): + raw_graph = deepcopy(gm.graph) + for budget in np.linspace(free_memory // start_factor, free_memory, sample_points): gm = metainfo_trace(gm, *data_gen()[0]) solver = CheckpointSolverRotor(gm.graph, free_memory=budget) try: - gm.graph = solver.solve() - peak_memory, step_time = bench(gm, - criterion, - partial(data_gen, batch_size=2048, shape=(3, 224, 224)), - num_steps=num_steps) + gm.graph = solver.solve(verbose=False) + peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps) except: peak_memory, step_time = budget / 1024**2, float('inf') peak_hist.append(peak_memory) step_hist.append(step_time) - return peak_hist, step_hist + gm.graph = deepcopy(raw_graph) + return np.linspace(free_memory // start_factor, free_memory, sample_points) / 1024**2, peak_hist, step_hist + + +class GPTLMModel(nn.Module): + """ + GPT Model + """ + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +class GPTLMLoss(nn.Module): + """ + GPT Loss + """ + + def __init__(self): + 
super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_xl(checkpoint=False): + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint) + + +def gpt2_6b(checkpoint=False): + return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) diff --git a/examples/tutorial/auto_parallel/demo_gpt2_medium.py b/examples/tutorial/auto_parallel/demo_gpt2_medium.py new file mode 100644 index 000000000..2739a4c2e --- /dev/null +++ b/examples/tutorial/auto_parallel/demo_gpt2_medium.py @@ -0,0 +1,108 @@ +import time +from argparse import ArgumentParser +from functools import partial + +import matplotlib.pyplot as plt +import torch +import torch.multiprocessing as mp +import torchvision.models as tm +from bench_utils import GPTLMLoss, bench_rotor, gpt2_medium + +import colossalai +from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor +from colossalai.fx import metainfo_trace, symbolic_trace +from colossalai.utils import free_port + + +def data_gen(batch_size, seq_len, vocab_size, device='cuda:0'): + """ + Generate random data for benchmarking + """ + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + attention_mask = torch.ones_like(input_ids, device=device) + return (input_ids, attention_mask), attention_mask + + +def _gpt2_benchmark(rank, world_size, port, batch_size, num_steps, sample_points, free_memory, start_factor): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = gpt2_medium() + + # 
trace and benchmark + data, mask = data_gen(batch_size, 1024, 50257, device='meta')[0] + gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask}) + gm = metainfo_trace(gm, data, mask) + budgets, peak_hist, step_hist = bench_rotor(gm, + GPTLMLoss(), + partial(data_gen, batch_size=batch_size, seq_len=1024, + vocab_size=50257), + num_steps=num_steps, + sample_points=sample_points, + free_memory=free_memory, + start_factor=start_factor) + + # print summary + print("==============test summary==============") + for budget, peak, step in zip(budgets, peak_hist, step_hist): + print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS') + + # plot valid results + fig, axs = plt.subplots(1, 2, figsize=(16, 8)) + valid_idx = step_hist.index(next(step for step in step_hist if step != float("inf"))) + + # plot peak memory vs. budget memory + axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:]) + axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--') + axs[0].set_xlabel("Budget Memory (MB)") + axs[0].set_ylabel("Peak Memory (MB)") + axs[0].set_title("Peak Memory vs. Budget Memory") + + # plot relative step time vs. budget memory + axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]]) + axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--') + axs[1].set_xlabel("Peak Memory (MB)") + axs[1].set_ylabel("Relative Step Time") + axs[1].set_title("Step Time vs. 
Peak Memory") + axs[1].set_ylim(0.8, 1.5) + + # save plot + fig.savefig("gpt2_benchmark.png") + + +def gpt2_benchmark(batch_size, num_steps, sample_points, free_memory, start_factor): + world_size = 1 + run_func_module = partial(_gpt2_benchmark, + world_size=world_size, + port=free_port(), + batch_size=batch_size, + num_steps=num_steps, + sample_points=sample_points, + free_memory=free_memory, + start_factor=start_factor) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == "__main__": + parser = ArgumentParser("GPT2 medium Auto Activation Benchmark") + parser.add_argument("--batch_size", type=int, default=8, help="batch size for benchmark, default 8") + parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5") + parser.add_argument( + "--sample_points", + type=int, + default=15, + help= + "number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15" + ) + parser.add_argument("--free_memory", + type=int, + default=56000, + help="maximum memory budget in MB for benchmark, default 56000 MB") + parser.add_argument( + "--start_factor", + type=int, + default=10, + help= + "start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 10" + ) + args = parser.parse_args() + + gpt2_benchmark(args.batch_size, args.num_steps, args.sample_points, args.free_memory * 1024**2, args.start_factor) diff --git a/examples/tutorial/auto_parallel/demo_resnet152.py b/examples/tutorial/auto_parallel/demo_resnet152.py new file mode 100644 index 000000000..5861371e8 --- /dev/null +++ b/examples/tutorial/auto_parallel/demo_resnet152.py @@ -0,0 +1,74 @@ +import time +from argparse import ArgumentParser +from copy import deepcopy +from functools import partial + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.multiprocessing as mp +import torchvision.models as tm +from bench_utils import bench + 
+import colossalai +from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor +from colossalai.fx import metainfo_trace, symbolic_trace +from colossalai.utils import free_port + + +def data_gen(batch_size, shape, device='cuda'): + """ + Generate random data for benchmarking + """ + data = torch.empty(batch_size, *shape, device=device) + label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000) + return (data,), label + + +def _resnet152_benchmark(rank, world_size, port, num_steps): + """Resnet152 benchmark + This benchmark test the through put of Resnet152 with our activation solver given the memory budget of 95% of + maximum GPU memory, and with the batch size of [512, 1024, 2048] + """ + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = tm.resnet152() + gm = symbolic_trace(model) + raw_graph = deepcopy(gm.graph) + peak_mems, through_puts, batch_sizes = [], [], [512, 1024, 2048] + for batch_size in batch_sizes: + batch_size = int(batch_size) + gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta')) + solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info()[0] * 0.95) + gm.graph = solver.solve() + peak_mem, step_time = bench(gm, + torch.nn.CrossEntropyLoss(), + partial(data_gen, batch_size=batch_size, shape=(3, 224, 224)), + num_steps=num_steps) + peak_mems.append(peak_mem) + through_puts.append(batch_size / step_time * 1.0e3) + gm.graph = deepcopy(raw_graph) + + # print results + print("===============test summary================") + for batch_size, peak_mem, through_put in zip(batch_sizes, peak_mems, through_puts): + print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s') + + plt.plot(batch_sizes, through_puts) + plt.xlabel("batch size") + plt.ylabel("through put (images/s)") + plt.title("Resnet152 benchmark") + plt.savefig("resnet152_benchmark.png") + + +def 
resnet152_benchmark(num_steps): + world_size = 1 + run_func_module = partial(_resnet152_benchmark, world_size=world_size, port=free_port(), num_steps=num_steps) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == "__main__": + parser = ArgumentParser("ResNet152 Auto Activation Through Put Benchmark") + parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5") + args = parser.parse_args() + + resnet152_benchmark(args.num_steps) diff --git a/examples/tutorial/auto_parallel/demo_resnet50.py b/examples/tutorial/auto_parallel/demo_resnet50.py new file mode 100644 index 000000000..4cbd53eba --- /dev/null +++ b/examples/tutorial/auto_parallel/demo_resnet50.py @@ -0,0 +1,107 @@ +import time +from argparse import ArgumentParser +from functools import partial + +import matplotlib.pyplot as plt +import torch +import torch.multiprocessing as mp +import torchvision.models as tm +from bench_utils import bench_rotor + +import colossalai +from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor +from colossalai.fx import metainfo_trace, symbolic_trace +from colossalai.utils import free_port + + +def data_gen(batch_size, shape, device='cuda'): + """ + Generate random data for benchmarking + """ + data = torch.empty(batch_size, *shape, device=device) + label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000) + return (data,), label + + +def _resnet50_benchmark(rank, world_size, port, batch_size, num_steps, sample_points, free_memory, start_factor): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = tm.resnet50() + + # trace and benchmark + gm = symbolic_trace(model) + gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta')) + budgets, peak_hist, step_hist = bench_rotor(gm, + torch.nn.CrossEntropyLoss(), + partial(data_gen, batch_size=batch_size, shape=(3, 224, 224)), + num_steps=num_steps, + 
sample_points=sample_points, + free_memory=free_memory, + start_factor=start_factor) + + # print summary + print("==============test summary==============") + for budget, peak, step in zip(budgets, peak_hist, step_hist): + print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS') + + # plot valid results + fig, axs = plt.subplots(1, 2, figsize=(16, 8)) + valid_idx = step_hist.index(next(step for step in step_hist if step != float("inf"))) + + # plot peak memory vs. budget memory + axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:]) + axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--') + axs[0].set_xlabel("Budget Memory (MB)") + axs[0].set_ylabel("Peak Memory (MB)") + axs[0].set_title("Peak Memory vs. Budget Memory") + + # plot relative step time vs. budget memory + axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]]) + axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--') + axs[1].set_xlabel("Peak Memory (MB)") + axs[1].set_ylabel("Relative Step Time") + axs[1].set_title("Step Time vs. 
Peak Memory") + axs[1].set_ylim(0.8, 1.5) + + # save plot + fig.savefig("resnet50_benchmark.png") + + +def resnet50_benchmark(batch_size, num_steps, sample_points, free_memory, start_factor): + world_size = 1 + run_func_module = partial(_resnet50_benchmark, + world_size=world_size, + port=free_port(), + batch_size=batch_size, + num_steps=num_steps, + sample_points=sample_points, + free_memory=free_memory, + start_factor=start_factor) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == "__main__": + parser = ArgumentParser("ResNet50 Auto Activation Benchmark") + parser.add_argument("--batch_size", type=int, default=128, help="batch size for benchmark, default 128") + parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5") + parser.add_argument( + "--sample_points", + type=int, + default=15, + help= + "number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15" + ) + parser.add_argument("--free_memory", + type=int, + default=11000, + help="maximum memory budget in MB for benchmark, default 11000 MB") + parser.add_argument( + "--start_factor", + type=int, + default=4, + help= + "start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 4" + ) + args = parser.parse_args() + + resnet50_benchmark(args.batch_size, args.num_steps, args.sample_points, args.free_memory * 1024**2, + args.start_factor) diff --git a/examples/tutorial/auto_parallel/imgs/gpt2_benchmark.png b/examples/tutorial/auto_parallel/imgs/gpt2_benchmark.png new file mode 100644 index 000000000..eec121758 Binary files /dev/null and b/examples/tutorial/auto_parallel/imgs/gpt2_benchmark.png differ diff --git a/examples/tutorial/auto_parallel/imgs/resnet50_benchmark.png b/examples/tutorial/auto_parallel/imgs/resnet50_benchmark.png new file mode 100644 index 000000000..0208c54fb Binary files /dev/null and 
b/examples/tutorial/auto_parallel/imgs/resnet50_benchmark.png differ