
# Zero Bubble Pipeline Parallelism

Author: Junwen Duan, Hongxin Liu

**Related Paper**
- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241)

## Introduction

Compared with the 1F1B scheme used in earlier work, the Zero Bubble V schedule splits the backward pass (B) into two stages, one computing the activation gradients and the other computing the weight gradients. A schedule of the form 1F-1B-1W can further reduce the pipeline bubble.
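
To illustrate the idea, here is a minimal, self-contained PyTorch sketch (not ColossalAI's actual implementation) of how the backward pass of a single layer can be split into an activation-gradient step (B) and a deferred weight-gradient step (W):

```python
import torch
import torch.nn as nn

# A toy linear layer: y = sum(x @ W^T)
layer = nn.Linear(4, 4, bias=False)
x = torch.randn(2, 4, requires_grad=True)
y = layer(x).sum()

# Stage B: compute only the activation gradient dL/dx, keeping the graph
# alive so the weight gradient can still be computed later.
(dx,) = torch.autograd.grad(y, x, retain_graph=True)

# Stage W: compute the weight gradient dL/dW later, at a point the scheduler
# chooses so that it fills what would otherwise be a pipeline bubble.
(dw,) = torch.autograd.grad(y, layer.weight)
```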

## Usage

We will demonstrate how to use ZeroBubble with the booster API on 4 GPUs.

### Step 1: Import the required libraries

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel

import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
```

### Step 2: Initialize the distributed environment

```python
# rank, world_size and port are provided by your launcher or test harness
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
```
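
If you start the script with torchrun, a convenient alternative is colossalai.launch_from_torch, which reads the rank, world size and master address from the environment variables that torchrun sets (a sketch; check the signature in your ColossalAI version, as older releases required a config argument):

```python
# torchrun --nproc_per_node=4 train.py
colossalai.launch_from_torch()
```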

### Step 3: Initialize the model and optimizer

Build the model and optimizer. Here we create a Llama model with 8 decoder layers; the PipelineGraph and pipeline schedule are then created with the get_v_schedule() function in the next step.

```python
# Global parameters
NUM_BATCH = 8
NUM_TOK_PER_BATCH = 4
NUM_LAYERS = 8
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4

# Init Llama from huggingface
configuration = LlamaConfig(
    hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
    intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
    num_hidden_layers=NUM_LAYERS,
    num_attention_heads=NUM_HEADS,
    num_key_value_heads=NUM_HEADS,
    attn_implementation="flash_attention_2",
)
model = LlamaModel(configuration).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1)
```

### Step 4: Initialize the pipeline schedule

Next, we create the PipelineGraph and the pipeline schedule using the get_v_schedule() function. The PipelineGraph is initialized with the following parameters: x_cost is the run time consumed by operation x of each model chunk, and x_mem is the amount of memory consumed by operation x of each model chunk. These values are estimated and filled in before the pipeline starts; in practice, better results can be obtained by measuring the actual run time and memory cost during model execution. In the example below, we assume the computation times of the model's forward (F), activation backward (B) and weight backward (W) are 1, 1 and 1 respectively, and that the p2p communication time is 1.

```python
# Init the schedule: pp_size and num_microbatches must match the plugin below
pp_size = 4
num_microbatches = 4
h, a, s = configuration.hidden_size, configuration.num_attention_heads, 1024
# Rough estimates of the memory consumed by F, B and W per model chunk
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f
graph = PipelineGraph(
    n_stage=pp_size,
    n_micro=num_microbatches,
    f_cost=1,
    b_cost=1,
    w_cost=1,
    c_cost=1,
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()
```
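
get_v_schedule() returns, for each pipeline stage, the ordered list of scheduled nodes (the F/B/W computations and the communication between them). As a quick sanity check (an illustrative sketch; the exact node attributes may differ across versions), you can inspect it like this:

```python
# zbv_schedule[i] is the node sequence executed by pipeline stage i
for stage, nodes in enumerate(zbv_schedule):
    print(f"stage {stage}: {len(nodes)} scheduled nodes")
```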

### Step 5: Initialize the booster

Pass pp_style="zbv" when initializing the plugin to use the ZeroBubble pipeline.

```python
plugin = HybridParallelPlugin(
    pp_size=4,
    num_microbatches=4,
    tp_size=1,
    sp_size=1,
    zero_stage=1,
    initial_scale=1,
    find_unused_parameters=True,
    pp_style="zbv",
    scheduler_nodes=zbv_schedule,
    num_model_chunks=2,
)

dp_size = plugin.dp_size
booster = Booster(plugin=plugin)
```
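
Before entering the training loop, wrap the model and optimizer with the booster; booster.boost returns the wrapped (model, optimizer, criterion, dataloader, lr_scheduler), of which only the first two are needed here:

```python
model, optimizer, _, _, _ = booster.boost(model, optimizer)
```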

### Step 6: Train the model

```python
steps = 10
for step in range(steps):
    input_embeddings = torch.rand(
        NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
    ).cuda()
    # Make sure every pipeline stage sees the same input
    dist.all_reduce(
        input_embeddings, group=plugin.pp_group
    )
    data_iter = iter([{"inputs_embeds": input_embeddings}])
    output = booster.execute_pipeline(
        data_iter,
        model,
        lambda x, y: x.last_hidden_state.mean(),  # criterion: mean of last hidden state as a toy loss
        optimizer,
        return_loss=True,
        return_outputs=True,
    )
    optimizer.step()
    optimizer.zero_grad()
```
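
execute_pipeline returns a dict with the loss (and, since return_outputs=True, the model outputs) on the stage that runs the final model chunk; on other stages the loss may be None. A minimal logging sketch to place inside the loop above, under that assumption:

```python
    loss = output["loss"]
    if loss is not None:
        print(f"step {step}: loss = {loss.item()}")
```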

## Advanced usage tips

In ColossalAI, you can obtain better training performance by combining the ZeroBubble pipeline with the metadata cache and with hybrid parallelism.

### 1. Use the metadata cache with ZeroBubble

Pass enable_metadata_cache=True when initializing the plugin to enable the metadata cache in the ZeroBubble pipeline.

```python
plugin = HybridParallelPlugin(
    pp_size=2,
    num_microbatches=4,
    tp_size=2,
    sp_size=2,
    zero_stage=1,
    initial_scale=1,
    enable_metadata_cache=True,
    find_unused_parameters=True,
    pp_style="zbv",
    scheduler_nodes=zbv_schedule,
    num_model_chunks=2,
)
```

### 2. Combine ZeroBubble with hybrid parallelism

Pass pp_size, tp_size and sp_size when initializing the plugin to use hybrid parallelism together with the ZeroBubble pipeline.

```python
plugin = HybridParallelPlugin(
    pp_size=2,
    num_microbatches=2,
    tp_size=2,
    sp_size=2,
    zero_stage=1,
    initial_scale=1,
    find_unused_parameters=True,
    pp_style="zbv",
    scheduler_nodes=zbv_schedule,
    num_model_chunks=2,
)
```

## Performance Metrics

| Hybrid parallel strategy | Pipeline parallel | Sequence parallel + pipeline parallel | Data parallel + pipeline parallel |
| ------------------------ | ----------------- | ------------------------------------- | --------------------------------- |
| With 1F1B                | 15.27 samples/sec | 17.22 samples/sec                     | 14.06 samples/sec                 |
| With Zero Bubble         | 17.36 samples/sec | 18.38 samples/sec                     | 14.44 samples/sec                 |

## Model Compatibility

| Shardformer/Model | Bert | Blip2 | Bloom | Chatglm2 | Command | Deepseek | Falcon | GPT2 | Gptj | Llama | Mistral | Opt | Qwen2 | Sam | T5 | Vit | Whisper |
| ----------------- | ---- | ----- | ----- | -------- | ------- | -------- | ------ | ---- | ---- | ----- | ------- | --- | ----- | --- | -- | --- | ------- |
| ZeroBubble        | ✔️   | ✔️    | ✔️    | ✔️       | ✔️      | ✔️       | ✔️     | ✔️   | ✔️   | ✔️    | ✔️      | ✔️  | ✔️    | ✔️  | ✔️ | ✔️  | ✔️      |

## API Reference

{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }}