[Feature] Zigzag Ring attention (#5905)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Edenzzzz
2024-08-16 13:56:38 +08:00
committed by GitHub
parent 887d2d579b
commit f5c84af0b0
50 changed files with 1870 additions and 326 deletions

View File

@@ -28,7 +28,7 @@ def all_reduce_mean(x: float, world_size: int) -> float:
return tensor.item()
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir, nsys=False):
class DummyProfiler:
def __init__(self):
self.step_number = 0
@@ -42,7 +42,29 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
def __exit__(self, exc_type, exc_value, traceback):
pass
class NsysProfiler:
def __init__(self, warmup_steps, active_steps):
self.step_number = 0
self.warmup_steps = warmup_steps
self.active_steps = active_steps
def step(self):
if self.step_number == self.warmup_steps:
torch.cuda.cudart().cudaProfilerStart()
elif self.step_number == self.warmup_steps + self.active_steps:
torch.cuda.cudart().cudaProfilerStop()
self.step_number += 1
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
if enable_flag:
if nsys:
return NsysProfiler(warmup_steps, active_steps)
return profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),