mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 18:09:06 +00:00
[Feature] Zigzag Ring attention (#5905)
* halfway * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add sp_mode to benchmark; fix varlen interface * update softmax_lse shape by new interface * change tester name * remove buffer clone; support packed seq layout * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -28,6 +28,7 @@ warnings.filterwarnings("ignore")
|
||||
# Constants
|
||||
# ==============================
|
||||
|
||||
# We have lots of llamas for your choice!
|
||||
MODEL_CONFIGS = {
|
||||
"100m": LlamaConfig(
|
||||
max_position_embeddings=4096,
|
||||
@@ -36,6 +37,7 @@ MODEL_CONFIGS = {
|
||||
intermediate_size=2048,
|
||||
hidden_size=1024,
|
||||
),
|
||||
"5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8),
|
||||
"7b": LlamaConfig(max_position_embeddings=4096),
|
||||
"13b": LlamaConfig(
|
||||
hidden_size=5120,
|
||||
@@ -68,9 +70,6 @@ def main():
|
||||
default="gemini",
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel."
|
||||
)
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
|
||||
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
|
||||
parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
|
||||
@@ -94,11 +93,24 @@ def main():
|
||||
|
||||
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
|
||||
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
|
||||
parser.add_argument("--profile", action="store_true", help="Profile the code", default=False)
|
||||
parser.add_argument("--profile", action="store_true", help="Profile the code")
|
||||
parser.add_argument(
|
||||
"--nsys",
|
||||
action="store_true",
|
||||
help="Use nsys for profiling. \
|
||||
You should put something like this before colossalai launch: \
|
||||
nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out",
|
||||
)
|
||||
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
|
||||
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
|
||||
parser.add_argument("--no_cache", action="store_true")
|
||||
parser.add_argument("--overlap_allgather", action="store_true")
|
||||
parser.add_argument(
|
||||
"--sp_mode",
|
||||
default="all_to_all",
|
||||
choices=["all_to_all", "ring_attn", "ring", "split_gather"],
|
||||
help="Sequence parallelism mode",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
colossalai.launch_from_torch()
|
||||
@@ -195,12 +207,12 @@ def main():
|
||||
num_model_chunks=args.n_chunks,
|
||||
zero_stage=args.zero,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
enable_sequence_parallelism=args.sp > 1,
|
||||
enable_fused_normalization=torch.cuda.is_available(),
|
||||
enable_flash_attention=args.xformers,
|
||||
microbatch_size=args.mbs,
|
||||
precision="bf16",
|
||||
overlap_p2p=args.overlap,
|
||||
enable_metadata_cache=not args.no_cache,
|
||||
overlap_allgather=args.overlap_allgather,
|
||||
**hybrid_kwargs,
|
||||
@@ -218,7 +230,6 @@ def main():
|
||||
microbatch_size=args.mbs,
|
||||
initial_scale=2**8,
|
||||
precision="bf16",
|
||||
overlap_p2p=args.overlap,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
@@ -295,6 +306,7 @@ def main():
|
||||
args.ignore_steps,
|
||||
1, # avoid creating massive log files
|
||||
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
|
||||
nsys=args.nsys,
|
||||
) as prof:
|
||||
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
|
||||
data_iter = iter(dataloader)
|
||||
@@ -320,13 +332,16 @@ def main():
|
||||
performance_evaluator.on_step_start(step)
|
||||
outputs = model(**batch)
|
||||
loss = outputs[0]
|
||||
del outputs # free memory
|
||||
|
||||
if dist.get_rank() == dist.get_world_size() - 1:
|
||||
print(f"Step {step} loss: {loss}")
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
performance_evaluator.on_step_end(**batch)
|
||||
prof.step()
|
||||
|
||||
performance_evaluator.on_fit_end()
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
@@ -17,7 +17,7 @@ limitations under the License.
|
||||
## OPT
|
||||
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
|
||||
|
||||
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
|
||||
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.
|
||||
|
||||
|
||||
## Our Modifications
|
||||
|
@@ -28,7 +28,7 @@ def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
return tensor.item()
|
||||
|
||||
|
||||
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
|
||||
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir, nsys=False):
|
||||
class DummyProfiler:
|
||||
def __init__(self):
|
||||
self.step_number = 0
|
||||
@@ -42,7 +42,29 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
class NsysProfiler:
|
||||
def __init__(self, warmup_steps, active_steps):
|
||||
self.step_number = 0
|
||||
self.warmup_steps = warmup_steps
|
||||
self.active_steps = active_steps
|
||||
|
||||
def step(self):
|
||||
if self.step_number == self.warmup_steps:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
elif self.step_number == self.warmup_steps + self.active_steps:
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
self.step_number += 1
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
if enable_flag:
|
||||
if nsys:
|
||||
return NsysProfiler(warmup_steps, active_steps)
|
||||
|
||||
return profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
|
||||
|
Reference in New Issue
Block a user