
ZeroBubble Pipeline Parallelism

Author: Junwen Duan, Hongxin Liu

Related Paper

Introduction

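ZeroBubble (V Schedule): The key idea is to split the backward pass into two parts. B computes the gradient with respect to the activations (inputs), and W computes the gradient with respect to the weights. Only B lies on the critical path between pipeline stages, so W can be scheduled later to fill otherwise idle slots. Combined with a schedule such as 1F1B1W, this reduces the pipeline bubble further than the 1F1B schedule used in earlier work.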
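The B/W split can be illustrated on a single linear layer with plain `torch.autograd.grad`. The sketch below is a minimal illustration of the idea, not ColossalAI's implementation: B computes only the input gradient that the previous stage is waiting for, while the weight-gradient computation is stored as a closure in a list (`deferred_w` is a hypothetical name) and executed later.

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(4, 3)
x = torch.randn(2, 4, requires_grad=True)  # activation received from the previous stage
y = layer(x)
grad_y = torch.ones_like(y)  # gradient received from the next stage

# B: only the activation gradient, which the previous stage needs right away
(grad_x,) = torch.autograd.grad(y, x, grad_outputs=grad_y, retain_graph=True)

# W: the weight gradient is off the critical path, so queue it instead of computing it now
deferred_w = [
    lambda: torch.autograd.grad(y, (layer.weight, layer.bias), grad_outputs=grad_y)
]

# ... later, e.g. in a slot where this stage would otherwise sit idle
grad_w, grad_b = deferred_w.pop()()
```

Because `retain_graph=True` keeps the autograd graph alive after B, W can still run at any later point, at the cost of holding the saved activations until then.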

Hands-On Practice

We now demonstrate how to use ZeroBubble with the Booster API on 4 GPUs.

step 1. Import libraries

```python
import torch
import torch.distributed as dist
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel

import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
```

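step 2. Initialize Distributed Environment and Parallelism Group

Start one process per GPU. Here `rank`, `world_size`, and `port` are provided by the launcher; when launching with torchrun, `colossalai.launch_from_torch()` reads them from the environment instead.

```python
# rank, world_size and port come from the launcher (one process per GPU)
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
```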

step 3. Initialize Module, Optimizer, and Pipeline Schedule

Build the model and optimizer. Here we create a Llama model with 8 decoder layers.

```python
# Global parameters
NUM_BATCH = 8
NUM_TOK_PER_BATCH = 4
NUM_LAYERS = 8
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4

# Initialize a small Llama model from Hugging Face transformers
configuration = LlamaConfig(
    hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
    intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
    num_hidden_layers=NUM_LAYERS,
    num_attention_heads=NUM_HEADS,
    num_key_value_heads=NUM_HEADS,
    # requires the flash-attn package; use "sdpa" or "eager" if it is not installed
    attn_implementation="flash_attention_2",
)
model = LlamaModel(configuration).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1)
```

step 4. Initialize the Pipeline Schedule

Next, create the PipelineGraph and generate the pipeline schedule with its get_v_schedule() method. The PipelineGraph takes the following parameters: x_cost is the runtime of operation x on each model chunk, and x_mem is the memory consumed by operation x on each model chunk, where x is f (forward), b (backward B), w (backward W), or c (p2p communication, cost only). These values are estimated and filled in before the pipeline starts; measuring the actual runtime and memory cost of the model usually yields a better schedule. In the following example, we assume that forward, backward B, backward W, and p2p communication each take 1 unit of time.

```python
# pp_size and num_microbatches must match the plugin configured in step 5
pp_size = 4
num_microbatches = 4

# Memory estimate per model chunk; s is the sequence length assumed for the estimate
h, a, s = configuration.hidden_size, configuration.num_attention_heads, 1024
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f

# Build the graph and generate the V schedule
graph = PipelineGraph(
    n_stage=pp_size,
    n_micro=num_microbatches,
    f_cost=1,
    b_cost=1,
    w_cost=1,
    c_cost=1,
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()
```
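Plugging the step-3 configuration into these formulas shows what the memory parameters encode. In this plain-Python sketch (with s = 1024 taken from the example above), F allocates memory while B and W release it, and mem_b is defined so that the three deltas cancel out over one micro-batch.

```python
# Values from the step-3 config: hidden_size = 4 * 4, num_attention_heads = 4,
# and s = 1024 as assumed in the schedule example
h, a, s = 16, 4, 1024

mem_f = 34 * h + 5 * a * s  # memory allocated by the forward pass (F)
mem_w = -32 * h             # memory released by the weight-gradient pass (W)
mem_b = -mem_w - mem_f      # B releases the rest, so F + B + W nets to zero

print(mem_f, mem_b, mem_w)  # 21024 -20512 -512
```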

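step 5. Initialize the Booster

Pass pp_style="zbv" when initializing the plugin to use the ZeroBubble pipeline, together with the schedule from step 4 as scheduler_nodes and num_model_chunks=2, since the V schedule places two model chunks on each pipeline stage. Then boost the model and optimizer.

```python
plugin = HybridParallelPlugin(
    pp_size=4,
    num_microbatches=4,
    tp_size=1,
    sp_size=1,
    zero_stage=1,
    initial_scale=1,
    find_unused_parameters=True,
    pp_style="zbv",
    scheduler_nodes=zbv_schedule,  # generated for 4 stages and 4 micro-batches
    num_model_chunks=2,  # two model chunks per stage for the V schedule
)

dp_size = plugin.dp_size
booster = Booster(plugin=plugin)
# wrap the model and optimizer so the pipeline schedule can run them
model, optimizer, _, _, _ = booster.boost(model, optimizer)
```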
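To see why each stage holds two chunks, the following plain-Python sketch maps the 8 decoder layers onto 4 stages in a V shape: the first chunk runs down the pipeline from stage 0 to stage 3, and the second chunk runs back up. The helper `v_shape_layout` is hypothetical and assumes the layers split evenly; ColossalAI's own layer assignment may distribute layers differently.

```python
def v_shape_layout(num_layers: int, pp_size: int) -> dict[int, list[list[int]]]:
    """Hypothetical helper: layers held by each stage with two chunks placed in a V shape."""
    total_chunks = 2 * pp_size  # two model chunks per stage
    if num_layers % total_chunks != 0:
        raise ValueError("this sketch assumes the layers split evenly into chunks")
    per_chunk = num_layers // total_chunks
    layout = {stage: [] for stage in range(pp_size)}
    for v in range(total_chunks):  # v walks the chunks in model order
        layers = list(range(v * per_chunk, (v + 1) * per_chunk))
        # first half goes down the pipeline (0 -> pp_size - 1), second half comes back up
        stage = v if v < pp_size else 2 * pp_size - 1 - v
        layout[stage].append(layers)
    return layout


print(v_shape_layout(num_layers=8, pp_size=4))
# {0: [[0], [7]], 1: [[1], [6]], 2: [[2], [5]], 3: [[3], [4]]}
```

With this placement, the last stage holds two consecutive chunks, so the hand-off between them needs no p2p transfer, and the final output is produced on stage 0.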

step 6. Train Your Model

```python
steps = 10
for step in range(steps):
    # random input embeddings of shape (batch, tokens, hidden size)
    input_embeddings = torch.rand(
        NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
    ).cuda()
    # sum across the pipeline group so every stage sees the same inputs
    dist.all_reduce(input_embeddings, group=plugin.pp_group)
    data_iter = iter([{"inputs_embeds": input_embeddings}])
    # run one pipelined step over all micro-batches
    output = booster.execute_pipeline(
        data_iter,
        model,
        lambda x, y: x.last_hidden_state.mean(),  # toy loss: mean of the last hidden state
        optimizer,
        return_loss=True,
        return_outputs=True,
    )
    optimizer.step()
    optimizer.zero_grad()
```
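With NUM_BATCH = 8 and num_microbatches = 4, each micro-batch carries 2 samples. The sketch below reproduces that split with `torch.split` on CPU to show the shapes involved; it illustrates the arithmetic only, not the code path inside execute_pipeline, and it assumes the batch size is divisible by the number of micro-batches.

```python
import torch

NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE = 8, 4, 16  # values from steps 3 and 5
num_microbatches = 4

assert NUM_BATCH % num_microbatches == 0, "the batch must split evenly into micro-batches"
micro_batch_size = NUM_BATCH // num_microbatches

batch = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE)
micro_batches = torch.split(batch, micro_batch_size, dim=0)  # tuple of views along dim 0

print(len(micro_batches), tuple(micro_batches[0].shape))  # 4 (2, 4, 16)
```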

Advanced Practice

In ColossalAI, you can further improve training performance by combining ZeroBubble with MetaCache and hybrid parallelism.

1. Use MetaCache with ZeroBubble

Pass enable_metadata_cache=True when initializing the plugin to use the metadata cache with the ZeroBubble pipeline.

```python
plugin = HybridParallelPlugin(
    pp_size=2,
    num_microbatches=4,
    tp_size=2,
    sp_size=2,
    zero_stage=1,
    initial_scale=1,
    enable_metadata_cache=True,
    find_unused_parameters=True,
    pp_style="zbv",
    scheduler_nodes=zbv_schedule,  # regenerate with n_stage=2, n_micro=4 to match this plugin
    num_model_chunks=2,
)
```

2. HybridParallel with ZeroBubble

Set pp_size, tp_size, and sp_size when initializing the plugin to combine hybrid parallelism with the ZeroBubble pipeline.

```python
plugin = HybridParallelPlugin(
    pp_size=2,
    num_microbatches=2,
    tp_size=2,
    sp_size=2,
    zero_stage=1,
    initial_scale=1,
    find_unused_parameters=True,
    pp_style="zbv",
    scheduler_nodes=zbv_schedule,  # regenerate with n_stage=2, n_micro=2 to match this plugin
    num_model_chunks=2,
)
```
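How many processes these hybrid configurations need can be checked with simple arithmetic. The helper below is a hypothetical sketch that assumes pipeline, tensor, and sequence parallelism each occupy their own axis of the device mesh, with data parallelism taking whatever is left; whether sp_size adds a separate axis in ColossalAI depends on the sequence-parallel mode in use. Under that assumption, pp_size=2, tp_size=2, sp_size=2 needs a multiple of 8 processes.

```python
def data_parallel_size(world_size: int, pp_size: int, tp_size: int, sp_size: int) -> int:
    """Hypothetical helper: data-parallel size left over when pp, tp and sp use separate axes."""
    model_parallel = pp_size * tp_size * sp_size
    if world_size % model_parallel != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by pp*tp*sp={model_parallel}"
        )
    return world_size // model_parallel


print(data_parallel_size(world_size=8, pp_size=2, tp_size=2, sp_size=2))  # 1
```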

Performance Benchmark

| HybridParallel Strategy | Pipeline Parallel | Sequence Parallel + Pipeline Parallel | Data Parallel + Pipeline Parallel |
|---|---|---|---|
| With 1F1B | 15.27 samples/sec | 17.22 samples/sec | 14.06 samples/sec |
| With Zero Bubble | 17.36 samples/sec | 18.38 samples/sec | 14.44 samples/sec |
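The relative gain of Zero Bubble over 1F1B in each column follows directly from the benchmark; the short script below does this arithmetic, with the dictionary simply restating the table values.

```python
# (1F1B, Zero Bubble) throughput in samples/sec, copied from the table
throughput = {
    "Pipeline Parallel": (15.27, 17.36),
    "Sequence Parallel + Pipeline Parallel": (17.22, 18.38),
    "Data Parallel + Pipeline Parallel": (14.06, 14.44),
}

# percentage speedup of Zero Bubble relative to 1F1B, rounded to one decimal
speedup = {
    name: round((zero_bubble / one_f_one_b - 1) * 100, 1)
    for name, (one_f_one_b, zero_bubble) in throughput.items()
}
print(speedup)
# {'Pipeline Parallel': 13.7, 'Sequence Parallel + Pipeline Parallel': 6.7, 'Data Parallel + Pipeline Parallel': 2.7}
```

In these measurements, the gain is largest for pure pipeline parallelism.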

3. Fine-tuning Scheduler Parameters

Model compatibility

Shardformer/Model Bert Blip2 Bloom Chatglm2 Command Deepseek Falcon GPT2 Gptj Llama Mistral Opt Qwen2 Sam T5 Vit Whisper
ZeroBubble ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️

API Reference

{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }}