[Feature] Zigzag Ring attention (#5905)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Authored by: Edenzzzz
Committed by: GitHub
Date: 2024-08-16 13:56:38 +08:00
Parent: 887d2d579b
Commit: f5c84af0b0
50 changed files with 1870 additions and 326 deletions
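
For context on the headline feature: in the usual zigzag layout for ring attention, each sequence is cut into 2 * sp_size chunks and sequence-parallel rank i keeps chunks i and 2 * sp_size - 1 - i, so the causal mask leaves every rank with a comparable amount of attention work. The snippet below is a minimal sketch of that layout only; the helper name zigzag_split and the shapes are assumptions, not code from this PR.

    import torch

    def zigzag_split(seq: torch.Tensor, sp_size: int, sp_rank: int, dim: int = 1) -> torch.Tensor:
        """Keep chunks (sp_rank, 2 * sp_size - 1 - sp_rank) of `seq` on this rank.

        Pairing chunk i with chunk 2 * sp_size - 1 - i balances the causal-attention
        workload across sequence-parallel ranks (illustrative sketch only).
        """
        chunks = seq.chunk(2 * sp_size, dim=dim)
        return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=dim)

    if __name__ == "__main__":
        x = torch.arange(16).view(1, 16)  # toy sequence of length 16, sp_size = 4
        for rank in range(4):
            print(rank, zigzag_split(x, sp_size=4, sp_rank=rank)[0].tolist())
        # rank 0 -> [0, 1, 14, 15], rank 1 -> [2, 3, 12, 13],
        # rank 2 -> [4, 5, 10, 11], rank 3 -> [6, 7, 8, 9]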

View File

@@ -6,6 +6,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.testing import assert_close
 
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
@@ -107,13 +108,13 @@ def run_pp(
 
     # check loss
     if stage_manager.is_last_stage(ignore_chunk=True):
-        assert torch.allclose(torch_loss, pp_ret["loss"])
+        assert_close(torch_loss, pp_ret["loss"])
 
     # check gradients
     for i in range(num_model_chunk):
         idx = world_size * i + rank
-        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
-        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
+        assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
 
     # step
     torch_optimizer.step()
@@ -123,8 +124,8 @@ def run_pp(
     # check updated param
     for i in range(num_model_chunk):
         idx = world_size * i + rank
-        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
-        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+        assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)
 
     # forward only
     with torch.no_grad():
@@ -135,14 +136,14 @@ def run_pp(
             sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
         )
         if stage_manager.is_last_stage(ignore_chunk=True):
-            assert torch.allclose(torch_loss, pp_ret["loss"])
+            assert_close(torch_loss, pp_ret["loss"])
 
         for layer in sharded_model:
             if layer.weight.grad is None:
                 assert layer.weight.grad is None and layer.bias.grad is None
             else:
-                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
-                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+                assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))
 
 
 @pytest.mark.dist
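
The only functional change in the hunks above is replacing bare assert torch.allclose(...) with torch.testing.assert_close(...), which picks dtype-aware default tolerances and, on failure, reports how many elements mismatch and by how much instead of raising a bare AssertionError. A small self-contained illustration (the tensors are made up):

    import torch
    from torch.testing import assert_close

    a = torch.tensor([1.0, 2.0, 3.0])
    b = a + 1e-7  # tiny floating-point noise

    # Old style: a failure raises a plain AssertionError with no detail.
    assert torch.allclose(a, b)

    # New style: dtype-aware default rtol/atol, and a descriptive error message
    # showing the greatest absolute/relative difference when the check fails.
    assert_close(a, b)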

View File

@@ -6,6 +6,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.testing import assert_close
 
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
@@ -103,13 +104,13 @@ def examine_pp(num_microbatch: int, batch_size: int):
 
     # check loss
     if stage_manager.is_last_stage():
-        assert torch.allclose(torch_loss, pp_ret["loss"])
+        assert_close(torch_loss, pp_ret["loss"])
 
     # check gradients
     for i in range(len(sharded_model)):
         idx = rank * num_local_layer + i
-        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
-        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
+        assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
 
     # step
     torch_optimizer.step()
@@ -119,8 +120,8 @@ def examine_pp(num_microbatch: int, batch_size: int):
     # check updated param
     for i in range(len(sharded_model)):
         idx = rank * num_local_layer + i
-        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
-        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+        assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)
 
     # forward only
     with torch.no_grad():
@@ -131,14 +132,14 @@ def examine_pp(num_microbatch: int, batch_size: int):
             sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
         )
         if stage_manager.is_last_stage():
-            assert torch.allclose(torch_loss, pp_ret["loss"])
+            assert_close(torch_loss, pp_ret["loss"])
 
         for layer in sharded_model:
             if layer.weight.grad is None:
                 assert layer.weight.grad is None and layer.bias.grad is None
             else:
-                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
-                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+                assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))
 
 
 def run_dist(
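
The commit log also mentions a new softmax_lse interface and an experimental Triton rescale. Ring attention accumulates per-KV-block flash-attention outputs with the standard log-sum-exp combine; below is a minimal PyTorch sketch under assumed shapes (out: batch x heads x seq x head_dim, lse: batch x heads x seq), not the PR's actual kernel.

    import torch

    def combine_blocks(out1, lse1, out2, lse2):
        """Merge two partial attention outputs using their log-sum-exp statistics."""
        lse = torch.logaddexp(lse1, lse2)         # combined normalizer
        w1 = torch.exp(lse1 - lse).unsqueeze(-1)  # rescale factor for block 1
        w2 = torch.exp(lse2 - lse).unsqueeze(-1)  # rescale factor for block 2
        return w1 * out1 + w2 * out2, lse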