Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-10 05:20:33 +00:00
[example] add profile util for llama
@@ -1,11 +1,12 @@
 import argparse
 import resource
+import time
 from contextlib import nullcontext
 
 import torch
 from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
-from performance_evaluator import PerformanceEvaluator
+from performance_evaluator import PerformanceEvaluator, get_profile_context
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
 from tqdm import tqdm
 from transformers import AutoConfig, AutoModelForCausalLM
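The newly imported get_profile_context helper lives in the example's performance_evaluator.py, which is not part of this diff. Judging from how it is called later in this diff (an enable flag, a warmup step count, an active step count, a save_dir, and a yielded object with a .step() method), a minimal torch.profiler-based sketch of such a helper could look like the following; the real implementation may differ.

# Hypothetical sketch of a get_profile_context-style helper (not the code in this commit).
# The signature and behavior are inferred from the call site shown later in this diff.
from contextlib import contextmanager

import torch


@contextmanager
def get_profile_context(enable_flag: bool, warmup_steps: int, active_steps: int, save_dir: str):
    if not enable_flag:
        # Profiling disabled: yield a stand-in that still exposes .step(),
        # so the training loop can call prof.step() unconditionally.
        class _NoOpProfiler:
            def step(self) -> None:
                pass

        yield _NoOpProfiler()
        return

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(wait=0, warmup=warmup_steps, active=active_steps),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(save_dir),
    ) as prof:
        yield prof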
@@ -76,6 +77,7 @@ def main():
     parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
     parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
+    parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False)
     args = parser.parse_args()
 
     colossalai.launch_from_torch()
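The new --profile flag only toggles tracing; the trace output directory is assembled later in this diff from the wall-clock time, the plugin name, and the model config. A small illustration of the resulting path (the values here are made up, not the benchmark's defaults):

# Illustration of the save_dir naming used later in this diff (values are hypothetical).
import time

plugin_name = "gemini"  # would come from --plugin
config = "7b"           # would come from --config
save_dir = f"profile/{time.strftime('%H:%M', time.localtime())}-{plugin_name}-llama-{config}"
print(save_dir)         # e.g. profile/14:07-gemini-llama-7b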
@@ -110,6 +112,7 @@ def main():
             extra_dp_size=args.extra_dp,
             enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.xformers,
+            max_prefetch=10,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
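max_prefetch=10 is passed only to the Gemini plugin here; judging by the name, it presumably bounds how many upcoming parameter chunks Gemini may gather ahead of the compute stream. A minimal, self-contained sketch of constructing the plugin with this option outside the benchmark script (argument values are illustrative, not the benchmark's actual defaults):

# Minimal sketch: standing up GeminiPlugin with prefetching outside the benchmark script.
# max_prefetch=10 mirrors the diff above; the other arguments are illustrative.
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()
plugin = GeminiPlugin(
    precision="bf16",
    enable_fused_normalization=torch.cuda.is_available(),
    max_prefetch=10,  # presumably the number of chunks Gemini may fetch ahead of use
)
booster = Booster(plugin=plugin)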
@@ -246,25 +249,37 @@ def main():
         f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
     )
 
-    if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
-        data_iter = iter(dataloader)
-        for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
-            performance_evaluator.on_step_start(step)
-            booster.execute_pipeline(
-                data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False
-            )
-            optimizer.step()
-            optimizer.zero_grad()
-            performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
-    else:
-        for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
-            performance_evaluator.on_step_start(step)
-            outputs = model(**batch)
-            loss = outputs[0]
-            booster.backward(loss, optimizer)
-            optimizer.step()
-            optimizer.zero_grad()
-            performance_evaluator.on_step_end(**batch)
+    with get_profile_context(
+        args.profile,
+        1,
+        len(dataloader) - 1,
+        save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
+    ) as prof:
+        if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
+            data_iter = iter(dataloader)
+            for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
+                performance_evaluator.on_step_start(step)
+                booster.execute_pipeline(
+                    data_iter,
+                    model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=optimizer,
+                    return_loss=False,
+                )
+                optimizer.step()
+                optimizer.zero_grad()
+                performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
+                prof.step()
+        else:
+            for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
+                performance_evaluator.on_step_start(step)
+                outputs = model(**batch)
+                loss = outputs[0]
+                booster.backward(loss, optimizer)
+                optimizer.step()
+                optimizer.zero_grad()
+                performance_evaluator.on_step_end(**batch)
+                prof.step()
 
     performance_evaluator.on_fit_end()
     coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
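For orientation, the helper is entered with a warmup of 1 step and len(dataloader) - 1 active steps, and prof.step() is called once per iteration in both branches of the loop. Assuming the helper wraps torch.profiler (see the sketch after the first hunk), the stepping pattern looks like this in isolation:

# Standalone illustration of the step-driven profiling pattern used in the new loop.
# Assumes a torch.profiler backend; the benchmark's helper may be implemented differently.
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

num_steps = 8  # stand-in for len(dataloader)
with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=0, warmup=1, active=num_steps - 1),  # mirrors (1, len(dataloader) - 1)
    on_trace_ready=tensorboard_trace_handler("profile/demo-llama"),  # hypothetical output dir
) as prof:
    for step in range(num_steps):
        # ... forward, backward, and optimizer.step() would run here ...
        prof.step()  # advance the profiler schedule once per training step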