Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 02:51:59 +00:00
[Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements

--------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
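For context, the zigzag ring-attention layout referenced in the title splits each sequence into 2 * sp_size chunks and gives rank i both chunk i and chunk (2 * sp_size - 1 - i), so every rank sees a balanced mix of cheap and expensive causal-attention positions. Below is a minimal sketch of that split; the helper name zigzag_split and the shapes are illustrative assumptions, not part of the ColossalAI API.

import torch


def zigzag_split(batch: torch.Tensor, sp_size: int, sp_rank: int, seq_dim: int = 1) -> torch.Tensor:
    """Return the slice of `batch` owned by `sp_rank` under a zigzag layout.

    The sequence is cut into 2 * sp_size equal chunks; rank i keeps chunk i and
    chunk (2 * sp_size - 1 - i), pairing early and late causal positions.
    """
    chunks = batch.chunk(2 * sp_size, dim=seq_dim)
    return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=seq_dim)


# Example: split a [batch=1, seq=16, hidden=4] tensor across sp_size=4 ranks.
x = torch.arange(16).reshape(1, 16, 1).expand(1, 16, 4)
shards = [zigzag_split(x, sp_size=4, sp_rank=r) for r in range(4)]
# rank 0 holds positions 0-1 and 14-15, rank 1 holds 2-3 and 12-13, and so on.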
@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer
from torch.testing import assert_close
+from transformers.modeling_outputs import BaseModelOutputWithPast

from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster

@@ -259,7 +260,6 @@ def run_forward_backward_with_hybrid_plugin(
    org_output = org_model(**unshard_test_data)
    org_loss = criterion(org_output)
    org_loss.backward()

    return org_loss, org_output, sharded_loss, sharded_output

@@ -302,11 +302,12 @@ def run_forward_backward_with_low_level_zero_plugin(

def check_output_hidden_state(
-    org_output: Tensor,
-    sharded_output: Tensor,
+    org_output: BaseModelOutputWithPast,
+    sharded_output: BaseModelOutputWithPast,
    stage_manager: Optional[PipelineStageManager] = None,
    atol: float = 1e-5,
    rtol: float = 1e-3,
+    shard_config: Optional[ShardConfig] = None,
):
    org_hidden_state = org_output.last_hidden_state

@@ -315,6 +316,14 @@ def check_output_hidden_state(
    else:
        sharded_hidden_state = sharded_output.last_hidden_state

+    # Check if the output sequence is gathered before cross entropy
+    if shard_config is not None:
+        seq_dim = 1
+        sp_group = shard_config.sequence_parallel_process_group
+        sp_size = shard_config.sequence_parallel_size
+        if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size:
+            org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)]
+
    assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)

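A single-process illustration of the comparison logic added above; the rank index and sp_size are hard-coded here instead of being read from shard_config and the sequence-parallel process group, so this is a sketch rather than the test's actual flow.

import torch
from torch.testing import assert_close

sp_size, rank, seq_dim = 4, 1, 1
org_hidden_state = torch.randn(2, 16, 8)  # full-sequence output of the unsharded model
# Pretend this rank's sharded model produced only its 1/sp_size slice of the sequence.
sharded_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[rank]

# Same condition as in check_output_hidden_state: shard the reference output
# only when the sharded output was not gathered back to the full sequence.
if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size:
    org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[rank]

assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=1e-5, rtol=1e-3)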
@@ -374,8 +383,11 @@ def get_grad_tensors_for_check(
        shard_grad = torch.cat(shard_grad_list, dim=dim)

    # embedding may be resized when using tensor parallel
-    if shard_grad.shape[0] > org_grad.shape[0]:
-        shard_grad = shard_grad[: org_grad.shape[0], :]
+    try:
+        if shard_grad.shape[0] > org_grad.shape[0]:
+            shard_grad = shard_grad[: org_grad.shape[0], :]
+    except:
+        pass
    if verbose and dist.get_rank() == 0:
        print(f"'{suffix}' grad: {org_grad}, {shard_grad}")

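The try/except above appears to handle embeddings whose vocabulary dimension was padded so it divides evenly across tensor-parallel ranks: the gathered gradient is sliced back to the original rows, and gradients the 2D slice does not apply to (e.g. 1D tensors) are silently skipped. A standalone sketch with made-up shapes, assuming that padding behaviour:

import torch

org_vocab, hidden, tp_size = 50257, 64, 4
# Vocabulary padded up to a multiple of tp_size before sharding the embedding.
padded_vocab = ((org_vocab + tp_size - 1) // tp_size) * tp_size  # 50260

org_grad = torch.randn(org_vocab, hidden)       # gradient of the unsharded embedding
shard_grad = torch.randn(padded_vocab, hidden)  # gathered gradient of the padded, sharded embedding

# Drop the padding rows so both gradients cover the same vocabulary entries;
# a 1D gradient would raise here, which the test's try/except silently skips.
if shard_grad.shape[0] > org_grad.shape[0]:
    shard_grad = shard_grad[: org_grad.shape[0], :]

assert shard_grad.shape == org_grad.shape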
@@ -404,9 +416,6 @@ def check_grad(
    org_grad = getattr_(org_model, suffix).weight.grad
    shard_grad = getattr_(sharded_model, suffix).weight.grad
    shard_weight = getattr_(sharded_model, suffix).weight
-    # if verbose and dist.get_rank() == 0:
-    #     print("shard_weight", shard_weight)
-    #     print("org_grad", org_grad)
    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
        dist.all_gather(shard_grad_list, shard_grad, tp_group)

@@ -440,7 +449,7 @@ def check_all_grad_tensors(check_tensors):
        "org_grad": tensor to be compared from the original model
        "shard_grad": tensor to be compared from the sharded model
    """
-    for suffix, check_info in check_tensors.items():
+    for idx, (suffix, check_info) in enumerate(check_tensors.items()):
        org_grad = check_info["org_grad"]
        shard_grad = check_info["shard_grad"]
        rtol = check_info["rtol"]