mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-31 16:40:41 +00:00
[Feature] Split cross-entropy computation in SP (#5959)
* halfway * fix cross-PP-stage position id length diff bug * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * update softmax_lse shape by new interface * change tester name * remove buffer clone; support packed seq layout * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements * adapt chatglm, command-R, qwen * debug * halfway * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * add sp_mode to benchmark; fix varlen interface * update softmax_lse shape by new interface * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements * add comments * q1 index only once * remove events to simplify stream sync * simplify forward/backward logic * 2d ring forward passed * 2d ring backward passed * fixes * fix ring attn loss * 2D ring backward + llama passed * merge * update logger * fix typo * rebase * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * remove typos * fixes * support GPT --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -28,7 +28,7 @@ MODEL_CONFIGS = {
|
||||
"118M": GPT2Config(activation_function="gelu"),
|
||||
"338M": GPT2Config(n_embd=1024, n_head=16, n_layer=24, activation_function="gelu"),
|
||||
"738M": GPT2Config(n_embd=1280, n_head=20, n_layer=36, activation_function="gelu"),
|
||||
"6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=4096, activation_function="gelu"),
|
||||
"6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=32768, activation_function="gelu"),
|
||||
}
|
||||
|
||||
|
||||
@@ -60,6 +60,8 @@ def main():
|
||||
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
|
||||
parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
|
||||
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
|
||||
parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
|
||||
parser.add_argument("--sp_mode", type=str, default="ring_attn", help="Sequence parallel mode")
|
||||
parser.add_argument("--mbs", type=int, default=1)
|
||||
parser.add_argument("--zero", type=int, default=0)
|
||||
parser.add_argument("--pp_style", type=str, default="1f1b")
|
||||
@@ -129,6 +131,9 @@ def main():
|
||||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
pp_style=args.pp_style,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
enable_sequence_parallelism=True,
|
||||
zero_stage=args.zero,
|
||||
num_model_chunks=args.num_model_chunks,
|
||||
enable_all_optimization=True,
|
||||
@@ -214,6 +219,8 @@ def main():
|
||||
performance_evaluator.on_step_start(step)
|
||||
outputs = model(**batch)
|
||||
loss = outputs[0]
|
||||
del outputs
|
||||
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
@@ -6,7 +6,6 @@ import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
|
||||
@@ -22,8 +21,11 @@ def divide(x: float, y: float) -> float:
|
||||
def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
if world_size == 1:
|
||||
return x
|
||||
tensor = torch.tensor([x], device=get_accelerator().get_current_device())
|
||||
dist.all_reduce(tensor)
|
||||
|
||||
# Use CPU tensor to avoid OOM/weird NCCl error
|
||||
gloo_group = dist.new_group(backend="gloo")
|
||||
tensor = torch.tensor([x], device="cpu")
|
||||
dist.all_reduce(tensor, group=gloo_group)
|
||||
tensor = tensor / world_size
|
||||
return tensor.item()
|
||||
|
||||
|
Reference in New Issue
Block a user