[Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
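For context on the headline feature: zigzag ring attention cuts each sequence into 2 * world_size chunks and gives every rank one chunk from the front half plus its mirror from the back half, so the causal mask leaves each rank roughly equal attention work. A minimal sketch of that layout, with names and shapes of my own choosing rather than ColossalAI's actual implementation:

    # Minimal sketch of the zigzag sequence layout (assumed names, not the
    # library's code). Under a causal mask, early chunks attend to few keys
    # and late chunks to many, so pairing a cheap chunk with an expensive one
    # balances attention FLOPs across the ring.
    import torch

    def zigzag_split(seq: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
        assert seq.size(0) % (2 * world_size) == 0, "pad the sequence first"
        chunks = seq.chunk(2 * world_size, dim=0)
        # rank i keeps chunk i and its mirror chunk (2 * world_size - 1 - i)
        return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=0)

For example, with world_size=4, rank 0 holds chunks (0, 7), rank 1 holds (1, 6), and so on, so no rank is stuck with only the late, expensive chunks.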
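The "experimental triton rescale" entry plausibly refers to merging partial attention outputs as key/value blocks travel around the ring; the merge is the standard log-sum-exp rescale. A plain-PyTorch sketch of that rule, with illustrative shapes rather than the Triton kernel itself:

    # Merge two partial attention outputs using their per-query log-sum-exp.
    # out_*: [..., seq, head_dim]; lse_*: [..., seq]. Illustrative only.
    import torch

    def merge_attn_outputs(out_a, lse_a, out_b, lse_b):
        lse = torch.logaddexp(lse_a, lse_b)               # combined normalizer
        out = out_a * (lse_a - lse).exp().unsqueeze(-1) \
            + out_b * (lse_b - lse).exp().unsqueeze(-1)   # reweight each partial sum
        return out, lse

This is the same online-softmax identity FlashAttention uses, which is why the softmax_lse shape mentioned in the commit list matters to the interface.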
@@ -6,6 +6,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.testing import assert_close

 import colossalai
 from colossalai.cluster import ProcessGroupMesh
@@ -107,13 +108,13 @@ def run_pp(

     # check loss
     if stage_manager.is_last_stage(ignore_chunk=True):
-        assert torch.allclose(torch_loss, pp_ret["loss"])
+        assert_close(torch_loss, pp_ret["loss"])

     # check gradients
     for i in range(num_model_chunk):
         idx = world_size * i + rank
-        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
-        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
+        assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

     # step
     torch_optimizer.step()
@@ -123,8 +124,8 @@ def run_pp(
     # check updated param
     for i in range(num_model_chunk):
         idx = world_size * i + rank
-        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
-        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+        assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)

     # forward only
     with torch.no_grad():
@@ -135,14 +136,14 @@ def run_pp(
             sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
         )
         if stage_manager.is_last_stage(ignore_chunk=True):
-            assert torch.allclose(torch_loss, pp_ret["loss"])
+            assert_close(torch_loss, pp_ret["loss"])

         for layer in sharded_model:
             if layer.weight.grad is None:
                 assert layer.weight.grad is None and layer.bias.grad is None
             else:
-                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
-                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+                assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))


 @pytest.mark.dist
@@ -6,6 +6,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.testing import assert_close

 import colossalai
 from colossalai.cluster import ProcessGroupMesh
@@ -103,13 +104,13 @@ def examine_pp(num_microbatch: int, batch_size: int):

     # check loss
     if stage_manager.is_last_stage():
-        assert torch.allclose(torch_loss, pp_ret["loss"])
+        assert_close(torch_loss, pp_ret["loss"])

     # check gradients
     for i in range(len(sharded_model)):
         idx = rank * num_local_layer + i
-        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
-        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
+        assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

     # step
     torch_optimizer.step()
@@ -119,8 +120,8 @@ def examine_pp(num_microbatch: int, batch_size: int):
     # check updated param
     for i in range(len(sharded_model)):
         idx = rank * num_local_layer + i
-        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
-        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+        assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)

     # forward only
     with torch.no_grad():
@@ -131,14 +132,14 @@ def examine_pp(num_microbatch: int, batch_size: int):
             sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
         )
         if stage_manager.is_last_stage():
-            assert torch.allclose(torch_loss, pp_ret["loss"])
+            assert_close(torch_loss, pp_ret["loss"])

         for layer in sharded_model:
             if layer.weight.grad is None:
                 assert layer.weight.grad is None and layer.bias.grad is None
             else:
-                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
-                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+                assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))


 def run_dist(
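The test changes above are mechanical: every `assert torch.allclose(...)` becomes `torch.testing.assert_close(...)`. A small standalone illustration of what the migration buys, using made-up tensors:

    import torch
    from torch.testing import assert_close

    actual = torch.tensor([1.0, 2.0, 3.0])
    expected = actual.clone()

    # Old style: on mismatch this raises a bare AssertionError with no detail.
    assert torch.allclose(actual, expected)

    # New style: default rtol/atol are chosen per dtype, and a failure reports
    # how many elements mismatch plus the greatest absolute/relative difference.
    assert_close(actual, expected)

Besides the richer failure messages, this keeps one tolerance convention across both pipeline-schedule tests instead of allclose's single fixed default.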