Low Level ZeRO

Low Level ZeRO covers ZeRO-DP stages 1 and 2. We refer to it simply as ZeRO below.

Examples of ZeRO and gradient accumulation

The code below shows only a typical gradient accumulation loop. It omits many details, such as how the loss is processed.

# example of ZeRO1 with gradient accumulation
...
outputs = model(input)
loss = SomeLoss(outputs)
if (idx + 1) % ACCUMULATE_STEP != 0:
    with booster.no_sync(model, optimizer):
        # inside this context, gradients are not synced during backward,
        # so each rank is left with its own local gradients.
        # Skipping the sync saves backward time.
        booster.backward(loss, optimizer)
        continue
else:
    # last micro-batch: sync all the accumulated gradients
    booster.backward(loss, optimizer)
    optimizer.step()
    ...

# example of ZeRO2 with gradient accumulation

...
outputs = model(input)
loss = SomeLoss(outputs)
# ZeRO2 splits the gradients across ranks, so it can NOT skip the sync
# while accumulating: every backward syncs gradients.
booster.backward(loss, optimizer)
if (idx + 1) % ACCUMULATE_STEP == 0:
    optimizer.step()
...
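
As a small illustration of the condition used in both loops above, the sketch below lists which iterations end with `optimizer.step()`. The value `ACCUMULATE_STEP = 4` and the helper name `is_update_step` are assumptions made for this example only, not part of the library.

```python
ACCUMULATE_STEP = 4  # assumed value, for illustration only


def is_update_step(idx: int, accumulate_step: int = ACCUMULATE_STEP) -> bool:
    """Return True on the micro-batch that should end with optimizer.step()."""
    return (idx + 1) % accumulate_step == 0


# With 8 micro-batches and ACCUMULATE_STEP = 4, the optimizer steps twice.
update_steps = [idx for idx in range(8) if is_update_step(idx)]
print(update_steps)  # [3, 7]
```

All other iterations only accumulate gradients. In the ZeRO1 loop these are the iterations that run under `booster.no_sync`.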

Design:

Notation

p32 denotes the param copy in the optimizer, p denotes the model param, and g denotes the gradient.

INIT

In low level ZeRO (stages 1 and 2), p32 is split. Unlike the previous implementation, we split each p32 evenly by world_size. Thus, rank0 gets the param list [p-00, p-10], rank1 gets the param list [p-01, p-11], etc.

In the implementation, we first pad p, if needed, so that its size is divisible by world_size. Then we view it with the shape [world_size, -1], and each rank obtains its own part of p32 by cloning the corresponding row.
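
The pad-then-split step can be sketched as follows. This is a minimal pure-Python illustration that uses lists in place of tensors. The function name `pad_and_split` and the zero padding value are assumptions for the example, not the library's actual code.

```python
def pad_and_split(flat_param, world_size, pad_value=0.0):
    """Pad a flat parameter so its length divides evenly, then split it per rank."""
    remainder = len(flat_param) % world_size
    padding = (world_size - remainder) % world_size  # 0 when already divisible
    padded = list(flat_param) + [pad_value] * padding
    shard_size = len(padded) // world_size
    # Same idea as viewing the padded tensor as [world_size, -1]: row r belongs to rank r.
    # Slicing builds new lists, mirroring the clone each rank keeps as its p32 part.
    return [padded[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]


p = [1.0, 2.0, 3.0, 4.0, 5.0]  # 5 elements, not divisible by 3
shards = pad_and_split(p, world_size=3)
print(shards)  # [[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]]
```

Here rank 2 receives the single padded element. A parameter whose size is already divisible by world_size is split without padding.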

BWD

To make communication more efficient, a gradient is first added to a bucket. When the bucket is full, each g in it is reshaped as [world_size, -1], and the parts that belong to the same rank are grouped together. The data structure looks like this:

{
0: [g-00, g-10],
1: [g-01, g-11],
2: [g-02, g-12]
}

After that, the gradients would be flattened by rank, and the data structure looks like this:

# g-X0 means flatten([g-00, g-10])
{
0: [g-X0],
1: [g-X1],
2: [g-X2]
}
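
The two dictionaries above can be reproduced with a short pure-Python sketch. Lists stand in for tensors, the gradients are assumed to be already padded, and the names `split_by_rank` and `regroup_bucket` are hypothetical helpers written for this example.

```python
def split_by_rank(flat_grad, world_size):
    """View a flat (already padded) gradient as [world_size, -1] rows."""
    shard_size = len(flat_grad) // world_size
    return [flat_grad[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]


def regroup_bucket(bucket, world_size):
    """Group shard r of every gradient under key r, then flatten each group."""
    grouped = {rank: [] for rank in range(world_size)}
    for grad in bucket:
        for rank, shard in enumerate(split_by_rank(grad, world_size)):
            grouped[rank].append(shard)
    # g-Xr = flatten([g-0r, g-1r, ...])
    flattened = {
        rank: [x for shard in shards for x in shard]
        for rank, shards in grouped.items()
    }
    return grouped, flattened


g0 = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]  # first gradient in the bucket
g1 = [1.0, 2.0, 3.0]                 # second gradient in the bucket
grouped, flattened = regroup_bucket([g0, g1], world_size=3)
print(grouped[0])    # [[0.1, 0.2], [1.0]]   i.e. [g-00, g-10]
print(flattened[0])  # [0.1, 0.2, 1.0]       i.e. g-X0
```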

For ZeRO1, we iterate over the dictionary and do an all_reduce on each entry. For ZeRO2, we can simply do a reduce-scatter.
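
To show how the two collectives differ, the sketch below simulates them in a single process with plain lists. It does not call `torch.distributed`. The functions `all_reduce` and `reduce_scatter` here are toy stand-ins that sum across ranks, and the two-rank data is made up for the example.

```python
def all_reduce(per_rank_values):
    """Simulated all_reduce: every rank ends up with the element-wise sum."""
    total = [sum(vals) for vals in zip(*per_rank_values)]
    return [list(total) for _ in per_rank_values]


def reduce_scatter(per_rank_chunks):
    """Simulated reduce-scatter: rank r ends up with only the sum of chunk r."""
    world_size = len(per_rank_chunks)
    return [
        [sum(vals) for vals in zip(*(per_rank_chunks[src][r] for src in range(world_size)))]
        for r in range(world_size)
    ]


# local_grads[rank][k] is g-Xk as computed locally on `rank` (toy values, 2 ranks).
local_grads = [
    {0: [1.0, 2.0], 1: [3.0, 4.0]},      # rank 0
    {0: [10.0, 20.0], 1: [30.0, 40.0]},  # rank 1
]
world_size = len(local_grads)

# ZeRO1 style: all_reduce every entry of the dictionary.
zero1 = [dict() for _ in range(world_size)]
for key in range(world_size):
    reduced = all_reduce([local_grads[rank][key] for rank in range(world_size)])
    for rank in range(world_size):
        zero1[rank][key] = reduced[rank]

# ZeRO2 style: one reduce-scatter, each rank receives only its own reduced part.
zero2 = reduce_scatter(
    [[local_grads[rank][k] for k in range(world_size)] for rank in range(world_size)]
)

print(zero1[0])  # {0: [11.0, 22.0], 1: [33.0, 44.0]}
print(zero2)     # [[11.0, 22.0], [33.0, 44.0]]
```

In this simulation every rank holds all reduced entries after the ZeRO1 path, while after the ZeRO2 path each rank holds only the part it owns. Both agree on the part a rank needs for its own p32.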

Optim

Since each rank holds its own p32 and the corresponding g, running optim.step() is straightforward.

However, we have to handle layer drop, where a layer is defined in the model but never used in forward. For instance:

class MlpModel(nn.Module):
    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(128, 256)
        self.drop_linear = nn.Linear(256, 256)
        self.linear2 = nn.Linear(256, 512)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

The solution is to build a mapping between p32, p, and g. Before optim.step(), we collect every p with requires_grad=True and p.grad is not None as the real working params, and then select the corresponding p32 and g.
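
The selection step can be sketched as below. This is a pure-Python illustration: `Param` is a hypothetical stand-in for a model parameter (not `torch.nn.Parameter`), and keying the p32 mapping by parameter name is an assumption made for the example.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Param:
    """Hypothetical stand-in for a model parameter."""
    name: str
    requires_grad: bool = True
    grad: Optional[List[float]] = None


def select_working_params(params, master_shards):
    """Return (p, p32, g) triples for params that actually received a gradient."""
    selected = []
    for p in params:
        if p.requires_grad and p.grad is not None:
            selected.append((p, master_shards[p.name], p.grad))
    return selected


params = [
    Param("linear1.weight", grad=[0.1, 0.2]),
    Param("drop_linear.weight"),  # unused in forward, so grad stays None
    Param("linear2.weight", grad=[0.3, 0.4]),
]
master_shards = {p.name: [0.0, 0.0] for p in params}  # toy p32 parts

working = select_working_params(params, master_shards)
print([p.name for p, _, _ in working])  # ['linear1.weight', 'linear2.weight']
```

The dropped layer is skipped, so the optimizer step only touches p32 parts that have a matching gradient.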