[Shardformer] Support zbv in Shardformer Policy (#6150)

* [feat] Shardformer support zbv

* [feat] support chatglm2, command, deepseek for zbv

* [feat] support zbv in shardformer policy:
falcon, gptj, mistral, opt, qwen2, t5, vit, whisper (a minimal usage sketch follows this commit list)

* [feat] support GPT2FusedLinearConv1D

* [feat] support GPT2FusedLinear (without tp)

* [fix] debug FusedConvLinear

* [shardformer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.

* [Shardformer] support FusedLinear1D base for zbv

* [shardformer] support zbv in FusedLinear1D base, Col, Row

* [shardformer] support zbv in blip2 and sam policy

* [shardformer] fix bug: incorrect number of gradients; add FusedLinear
base test case

* [fix] fix incorrect number of gradients

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [Shardformer] add en doc for zbv;

* [fix] fix typo in Model compatibility table

* [fix] fix API Reference typo

* [Shardformer] add zh-Hans doc for zbv

* [fix] fix Linear name; update en & zh doc

* [fix] fix shardformer doc import err

* [fix] fix shardconfig import in doc

* [fix] fix shardformer doc

* [fix] fix shardconfig doc

* [fix] fix config

* [fix] remove shardconfig

* [fix] fix doc

* [feat] add zbv doc string

* [fix] rm doc

* [fix] fix doc

* [fix] empty zbv doc

* [fix] fix torch version

* [fix] fix torch version

* [fix] fix torch versions

* [fix] fix torch versions

* [fix] fix pyramid versions

* [fix] fix pyramid, zope version

* [fix] try fix workflow

* [fix] try import ShardConfig in yml

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix ci

* [fix] fix zbv doc

* [fix] fix param for qkv linear, gpt2 fused linear; fix requirements

* [fix] fix policy use fused_linear

* [fix] fix weight grad None; error caused by weight ptr change

* [fix] fix comm in WeightGradStore

* [fix] fix WeightGradStore pop param

* [fix] remove useless param in doc; fix gpt2 qkv test;

* [shardformer] simplify execute_w_pass_grad_accum;

* [fix] rm useless comments

* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass

* [shardformer] Run meaningful doc test

* [shardformer] fix doc test cmd

---------
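
Context for the feature itself: zbv is the zero-bubble V-schedule pipeline, which splits each backward step into an input-gradient (B) pass and a deferred weight-gradient (W) pass. Below is a minimal usage sketch, assuming the HybridParallelPlugin / PipelineGraph API described in the en/zh zbv docs this PR adds; all sizes and cost/memory numbers are illustrative, not taken from this commit.

    # Hedged sketch: driving a zbv-enabled Shardformer policy through the
    # HybridParallelPlugin. Assumes pp_style="zbv", num_model_chunks, and
    # scheduler_nodes are accepted as in this PR's zbv documentation.
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin
    from colossalai.pipeline.schedule.v_schedule import PipelineGraph

    pp_size, num_microbatches = 2, 4

    # Build the V-schedule node list from per-pass cost/memory estimates
    # (f = forward, b = input-grad backward, w = weight-grad backward, c = comm).
    scheduler_nodes = PipelineGraph(
        n_stage=pp_size,
        n_micro=num_microbatches,
        f_cost=1, b_cost=1, w_cost=1, c_cost=1,  # illustrative costs
        f_mem=1, b_mem=0, w_mem=-1,              # rough activation-memory deltas
    ).get_v_schedule()

    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=pp_size,
        num_microbatches=num_microbatches,
        pp_style="zbv",                  # select the zero-bubble V schedule
        num_model_chunks=2,              # zbv places two model chunks per stage
        scheduler_nodes=scheduler_nodes,
    )
    booster = Booster(plugin=plugin)
    # model, optimizer, criterion, dataloader are then wrapped via booster.boost(...)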

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: duanjunwen
Date: 2025-01-02 10:22:26 +08:00
Committed by: GitHub
Parent: af06d162cf
Commit: a9bedc7a43
27 changed files with 3511 additions and 316 deletions


@@ -8,7 +8,8 @@ from torch.testing import assert_close

 import colossalai
 from colossalai.lazy import LazyInitContext
-from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+from colossalai.pipeline.weight_grad_store import WeightGradStore
+from colossalai.shardformer.layer import GPT2FusedLinearConv, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@@ -118,11 +119,82 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
     assert_close(target_grad, linear_row.weight.grad)


+def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: str):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    linear = Conv1D(192, 48).cuda()
+    with ctx:
+        linear_copy = Conv1D(192, 48).cuda()
+    linear_base = GPT2FusedLinearConv.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode)
+
+    assert linear.weight.shape == torch.Size([48, 192])
+    assert linear_base.weight.shape == torch.Size([48, 192])
+    assert linear_base.bias.shape == torch.Size([192])
+    assert linear_copy.weight is linear_base.weight
+    assert linear_copy.bias is linear_base.bias
+
+    # ensure weights are reversibly loadable
+    linear_base.load_state_dict(linear.state_dict())
+    linear.load_state_dict(linear_base.state_dict())
+
+    # check computation correctness
+    x = torch.rand(1, 4, 48).cuda()
+    out = linear(x)
+    gather_out = linear_base(x)
+    assert_close(out, gather_out)
+
+    # check backward correctness
+    out.sum().backward()
+    gather_out.sum().backward()
+
+    # check the input gradients & weight gradients
+    assert_close(out.grad, gather_out.grad)
+    assert_close(linear.weight.grad, linear_base.weight.grad)
+
+
+def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: str):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    linear = Conv1D(192, 48).cuda()
+    with ctx:
+        linear_copy = Conv1D(192, 48).cuda()
+    linear_base = GPT2FusedLinearConv.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode, use_zbv=True)
+
+    assert linear.weight.shape == torch.Size([48, 192])
+    assert linear_base.weight.shape == torch.Size([48, 192])
+    assert linear_base.bias.shape == torch.Size([192])
+    assert linear_copy.weight is linear_base.weight
+    assert linear_copy.bias is linear_base.bias
+
+    # ensure weights are reversibly loadable
+    linear_base.load_state_dict(linear.state_dict())
+    linear.load_state_dict(linear_base.state_dict())
+
+    # check computation correctness
+    x = torch.rand(1, 4, 48).cuda()
+    out = linear(x)
+    gather_out = linear_base(x)
+    assert_close(out, gather_out)
+
+    # check backward correctness
+    out.sum().backward()
+    gather_out.sum().backward()
+    WeightGradStore.flush(chunk=0)  # flush buffer to chunk 0 Queue
+    WeightGradStore.pop(chunk=0)
+
+    # check the input gradients & weight gradients
+    assert_close(out.grad, gather_out.grad)
+    assert_close(linear.weight.grad, linear_base.weight.grad)
+
+
 @parameterize("lazy_init", [False, True])
 @parameterize("seq_parallel_mode", ["split_gather", None])
 def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
     check_linear_conv_1d_col(lazy_init, seq_parallel_mode)
     check_linear_conv_1d_row(lazy_init, seq_parallel_mode)
+    check_linear_conv_1d_without_weight_grad_store(lazy_init, None)
+    check_linear_conv_1d_with_weight_grad_store(lazy_init, None)


 def run_dist(rank, world_size, port):
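
The two new checks above pin down the zbv contract for the base fused layer: without use_zbv the weight gradient materializes during backward(), while with use_zbv the dW computation is deferred into WeightGradStore and only materializes after flush(chunk=0) and pop(chunk=0). As a rough mental model, here is a toy re-implementation of that put/flush/pop contract — a simplified sketch for intuition, not the source of colossalai.pipeline.weight_grad_store:

    # Toy model of the WeightGradStore contract exercised by the test above
    # (simplified sketch; names mirror the calls visible in the diff).
    import queue


    class ToyWeightGradStore:
        cache = []  # records accumulated during B (input-grad) passes
        weight_grad_queue = {0: queue.Queue(), 1: queue.Queue()}  # one queue per model chunk

        @classmethod
        def put(cls, total_input, grad_output, weight, func):
            # B pass: remember what is needed for dW instead of computing it now
            cls.cache.append((total_input, grad_output, weight, func))

        @classmethod
        def flush(cls, chunk=0):
            # seal the current micro-batch's records into the chunk's queue
            cls.weight_grad_queue[chunk].put(cls.cache)
            cls.cache = []

        @classmethod
        def pop(cls, chunk=0):
            # W pass: replay the records, materializing weight.grad
            stored_grads = cls.weight_grad_queue[chunk].get()
            for total_input, grad_output, weight, func in stored_grads:
                func(total_input, grad_output, weight)

The second changed test file, shown below, applies the same base-layer treatment to FusedLinear.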


@@ -7,7 +7,7 @@ from torch.testing import assert_close

 import colossalai
 from colossalai.lazy import LazyInitContext
-from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
+from colossalai.shardformer.layer import FusedLinear, FusedLinear1D_Col, FusedLinear1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@@ -120,12 +120,45 @@ def check_linear_1d_col_row(lazy_init: bool):
     assert_close(target_grad2, linear_row.weight.grad)


+@parameterize("lazy_init", [False, True])
+def check_linear_1d_base(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    linear = nn.Linear(8, 80).cuda()
+    with ctx:
+        linear_copy = nn.Linear(8, 80).cuda()
+    linear_base = FusedLinear.from_native_module(linear_copy)
+
+    assert linear.weight.shape == torch.Size([80, 8])
+    assert linear.bias.shape == torch.Size([80])
+    assert linear_base.weight.shape == torch.Size([80, 8])
+    assert linear_base.bias.shape == torch.Size([80])
+    assert linear_copy.weight is linear_base.weight
+    assert linear_copy.bias is linear_base.bias
+
+    # ensure weights are reversibly loadable
+    linear_base.load_state_dict(linear.state_dict())
+    linear.load_state_dict(linear_base.state_dict())
+
+    # check computation correctness
+    x = torch.rand(4, 8).cuda()
+    out = linear(x)
+    base_out = linear_base(x)
+    assert_close(out, base_out)
+
+    # check backward correctness
+    out.sum().backward()
+    base_out.sum().backward()
+
+    assert_close(linear.weight.grad, linear_base.weight.grad)
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     check_linear_1d_col()
     check_linear_1d_row()
     check_linear_1d_col_row()
+    check_linear_1d_base()


 @rerun_if_address_is_in_use()
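
The diff is truncated at the @rerun_if_address_is_in_use() decorator. For completeness, such run_dist entry points are typically launched with the spawn helper imported at the top of the file; the test name below is hypothetical:

    # Hedged sketch of the usual launch pattern (test name hypothetical).
    @rerun_if_address_is_in_use()
    def test_fused_linear_1d():
        spawn(run_dist, nprocs=2)  # start 2 ranks; spawn supplies rank/world_size/port


    if __name__ == "__main__":
        test_fused_linear_1d()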