Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 09:07:51 +00:00
[Feature] optimize PP overlap (#5735)
* update to fully overlap, still debugging
* improve interface
* fixed deadlock bug
* debug NaN loss
* (experimental) use one comm group for send_fw_recv_fw to fix NaN
* cleaned up interfaces; use one batch p2p for all
* clean up; removed the double p2p batch case
* p2p test passed
* improve overlap: send fwd before backward
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* tentatively use 2 p2p batches
* remove two p2p batches
* fix typos
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* remove pp.sh

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@notebook-c55824c0-7742-45e8-9591-c855bb77ad29-0.notebook-c55824c0-7742-45e8-9591-c855bb77ad29.colossal-ai.svc.cluster.local>
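The commit message refers to batching the pipeline's point-to-point sends and receives into a single call ("use one batch p2p for all"). The snippet below is a minimal sketch of that general pattern with PyTorch's torch.distributed API, not the ColossalAI scheduler itself; the function name, peer rank, and tensor buffers are illustrative.

import torch
import torch.distributed as dist

def send_fwd_recv_bwd_batched(fwd_output: torch.Tensor, grad_buffer: torch.Tensor, peer: int):
    """Queue the forward-activation send and the matching gradient receive as one
    batch, so a single batched call issues both and the calling stage can keep
    computing until the requests are waited on."""
    ops = [
        dist.P2POp(dist.isend, fwd_output, peer),
        dist.P2POp(dist.irecv, grad_buffer, peer),
    ]
    reqs = dist.batch_isend_irecv(ops)
    return reqs  # caller overlaps compute, then calls req.wait() on each request

Overlapping works because the requests are only synchronized when their results are actually needed, which is what the `--overlap` flag added in this diff enables in the pipeline schedule.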
@@ -1,9 +1,11 @@
import argparse
import resource
import time
import warnings
from contextlib import nullcontext

import torch
import torch.distributed as dist
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator, get_profile_context
@@ -21,11 +23,19 @@ from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import PipelineGradientCheckpointConfig

warnings.filterwarnings("ignore")
# ==============================
# Constants
# ==============================

MODEL_CONFIGS = {
    "100m": LlamaConfig(
        max_position_embeddings=4096,
        num_hidden_layers=4,
        num_attention_heads=32,
        intermediate_size=2048,
        hidden_size=1024,
    ),
    "7b": LlamaConfig(max_position_embeddings=4096),
    "13b": LlamaConfig(
        hidden_size=5120,
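The named entries above are plain transformers LlamaConfig objects. For context, a config key chosen via --config turns into a model roughly as follows; this is a standalone illustration using transformers' LlamaForCausalLM, not the benchmark's exact construction path (the script builds the model under a LazyInitContext).

from transformers import LlamaForCausalLM

# Illustrative only: look up a named config and build a model from it.
config = MODEL_CONFIGS["100m"]
model = LlamaForCausalLM(config)
print(f"parameters: {sum(p.numel() for p in model.parameters()):,}")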
@@ -58,6 +68,9 @@ def main():
        default="gemini",
        help="Choose which plugin to use",
    )
    parser.add_argument(
        "--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel."
    )
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
    parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
@@ -78,11 +91,13 @@ def main():
    parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
    parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
    parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
    parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False)
    parser.add_argument(
        "--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation", default=False
    )

    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
    parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
    parser.add_argument("--profile", action="store_true", help="Profile the code", default=False)
    parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
    parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
    parser.add_argument("--no_cache", action="store_true")
    args = parser.parse_args()

    colossalai.launch_from_torch()
@@ -98,6 +113,7 @@ def main():
                num_ckpt_layers_per_stage=[19, 19, 19, 13],
            ),
            "num_layers_per_stage": [19, 20, 20, 21],
            "pp_style": "interleaved",
        }
        if args.custom_ckpt
        else {}
@@ -174,6 +190,8 @@ def main():
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            pp_style=args.pp_style,
            num_model_chunks=args.n_chunks,
            zero_stage=args.zero,
            sp_size=args.sp,
            enable_sequence_parallelism=args.sp > 1,
@@ -182,12 +200,16 @@ def main():
            microbatch_size=args.mbs,
            precision="bf16",
            dp_outside=False,
            overlap_p2p=args.overlap,
            enable_metadata_cache=not args.no_cache,
            **hybrid_kwargs,
        )
    elif args.plugin == "3d_cpu":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            pp_style=args.pp_style,
            num_model_chunks=args.n_chunks,
            zero_stage=args.zero,
            cpu_offload=True,
            enable_fused_normalization=torch.cuda.is_available(),
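The `--overlap` flag added earlier is wired directly into the plugin's `overlap_p2p` argument in both branches above. A minimal sketch of the same configuration outside the benchmark script (the parameter values below are illustrative, not the script's defaults):

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Run under torchrun so launch_from_torch() can read the distributed environment.
colossalai.launch_from_torch()
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=4,                # 4-stage pipeline
    pp_style="interleaved",   # or "1f1b"
    num_model_chunks=2,
    microbatch_size=1,
    precision="bf16",
    overlap_p2p=True,         # overlap pipeline p2p communication with compute
)
booster = Booster(plugin=plugin)
# model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)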
@@ -195,6 +217,7 @@ def main():
            microbatch_size=args.mbs,
            initial_scale=2**8,
            precision="bf16",
            overlap_p2p=args.overlap,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")
@@ -210,10 +233,11 @@ def main():
        config = MODEL_CONFIGS[args.config]
    else:
        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
    torch.cuda.manual_seed(42)
    dataset = RandomDataset(
        num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
    )
    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)

    # ==============================
    # Initialize Model and Optimizer
@@ -251,6 +275,7 @@ def main():
    optimizer = HybridAdam(model.parameters())
    torch.set_default_dtype(torch.bfloat16)
    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

    torch.set_default_dtype(torch.float)
    coordinator.print_on_master(
        f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
@@ -269,15 +294,19 @@ def main():
            data_iter = iter(dataloader)
            for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
                performance_evaluator.on_step_start(step)
                booster.execute_pipeline(
                outputs = booster.execute_pipeline(
                    data_iter,
                    model,
                    criterion=lambda outputs, inputs: outputs[0],
                    optimizer=optimizer,
                    return_loss=False,
                    return_loss=True,
                )
                loss = outputs["loss"]
                if dist.get_rank() == dist.get_world_size() - 1:
                    print(f"Step {step} loss: {loss}")
                optimizer.step()
                optimizer.zero_grad()

                performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
                prof.step()
        else:
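A note on the loss read-back in this hunk: with `return_loss=True`, `execute_pipeline` only materializes the loss on the last pipeline stage, which the script identifies via the last global rank. If the plugin exposes its pipeline stage manager (an attribute of HybridParallelPlugin in recent ColossalAI releases; treat this as an assumption), the same guard could be written as the sketch below.

# Hedged alternative to the global-rank check; assumes booster.plugin exposes
# a PipelineStageManager as `stage_manager`.
if booster.plugin.stage_manager.is_last_stage():
    print(f"Step {step} loss: {outputs['loss'].item():.4f}")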
@@ -288,6 +317,7 @@ def main():
                booster.backward(loss, optimizer)
                optimizer.step()
                optimizer.zero_grad()

                performance_evaluator.on_step_end(**batch)
                prof.step()