from time import time
from typing import Optional

import torch
import torch.distributed as dist
from torch import Tensor

from colossalai.accelerator import get_accelerator
from colossalai.cluster import DistCoordinator

def divide(x: float, y: float) -> float:
    if y == 0:
        return float("inf")
    elif y == float("inf"):
        return float("nan")
    return x / y

@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
    if world_size == 1:
        return x
    tensor = torch.tensor([x], device=get_accelerator().get_current_device())
    dist.all_reduce(tensor)
    tensor = tensor / world_size
    return tensor.item()

class Timer:
    def __init__(self) -> None:
        self.start_time: Optional[float] = None
        self.duration: float = 0.0

    def start(self) -> None:
        self.start_time = time()

    def end(self) -> None:
        assert self.start_time is not None
        self.duration += time() - self.start_time
        self.start_time = None

    def reset(self) -> None:
        self.duration = 0.0

class PerformanceEvaluator:
    """
    Callback for evaluating the performance of the model.

    Args:
        model_numel: The number of parameters of the model.
        num_layers: The number of transformer layers of the model.
        hidden_size: The hidden size of the model.
        vocab_size: The vocabulary size of the model.
        enable_grad_checkpoint: Whether gradient checkpointing is enabled.
        ignore_steps: The number of warm-up steps to exclude from the measurement.
        dp_world_size: The data-parallel world size. Defaults to the global world size.
    """

    def __init__(
        self,
        model_numel: int,
        num_layers: int,
        hidden_size: int,
        vocab_size: int,
        enable_grad_checkpoint: bool = False,
        ignore_steps: int = 0,
        dp_world_size: Optional[int] = None,
    ) -> None:
        self.model_numel = model_numel
        self.enable_grad_checkpoint = enable_grad_checkpoint
        self.ignore_steps = ignore_steps
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        self.coordinator = DistCoordinator()
        self.dp_world_size = dp_world_size or self.coordinator.world_size
        self.disable: bool = False
        self.timer = Timer()
        self.num_samples: int = 0
        self.flop_megatron = 0
        self.flop: int = 0

    def on_step_start(self, step: int) -> None:
        self.disable = self.ignore_steps > 0 and step < self.ignore_steps
        if self.disable:
            return
        get_accelerator().synchronize()
        self.timer.start()

    def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
        if self.disable:
            return
        get_accelerator().synchronize()
        self.timer.end()

        batch_size, seq_len = input_ids.shape

        self.num_samples += batch_size
        # Forward + backward is counted as 3x the forward cost; activation checkpointing adds one extra forward.
        checkpoint_activations_factor = 3 + int(self.enable_grad_checkpoint)
        # Analytical transformer FLOP count following the Megatron-LM formula:
        # 24 * factor * B * s * l * h^2 * (1 + s / (6h) + V / (16 * l * h))
        self.flop_megatron += (
            24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2)
        ) * (
            1.0 + (seq_len / (6.0 * self.hidden_size)) + (self.vocab_size / (16.0 * self.num_layers * self.hidden_size))
        )
        # Simpler estimate: ~2 FLOPs per parameter per token for the forward pass, scaled by the same factor.
        self.flop += batch_size * seq_len * self.model_numel * 2 * checkpoint_activations_factor

    def on_fit_end(self) -> None:
        avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
        avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
        mp_world_size = self.coordinator.world_size // self.dp_world_size
        avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size
        avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
        self.coordinator.print_on_master(
            f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration}, "
            f"avg_throughput: {avg_throughput}"
        )
        self.coordinator.print_on_master(
            f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU by Megatron: {avg_tflops_per_gpu_megatron:.2f}, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
        )
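

# Hypothetical usage sketch (not part of the original benchmark): a minimal
# example of wiring PerformanceEvaluator into a training loop. `model`,
# `optimizer`, and `dataloader` are assumed to be an HF-style causal LM, its
# optimizer, and a dataloader yielding dicts containing an "input_ids" tensor.
def _example_usage(model, optimizer, dataloader) -> None:
    config = model.config  # assumed HF-style config with the attributes below
    evaluator = PerformanceEvaluator(
        model_numel=sum(p.numel() for p in model.parameters()),
        num_layers=config.num_hidden_layers,
        hidden_size=config.hidden_size,
        vocab_size=config.vocab_size,
        enable_grad_checkpoint=False,
        ignore_steps=2,  # skip warm-up steps so they don't skew the timing
    )
    for step, batch in enumerate(dataloader):
        evaluator.on_step_start(step)
        loss = model(**batch, labels=batch["input_ids"]).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        evaluator.on_step_end(batch["input_ids"])
    evaluator.on_fit_end()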