ColossalAI/examples/language/model_utils.py
flybird11111 29695cf70c
[example] add gpt2 benchmark example script. ()
* benchmark gpt2
* [doc] fix typo in Colossal-LLaMA-2/README.md ()
* [workflow] fixed build CI ()
* [ci] fixed booster test ()
* [ci] fixed ddp test ()
* fix typo in applications/ColossalEval/README.md ()
* [ci] fix shardformer tests (): revert p2p, add enable_metadata_cache option, re-enable t5 tests
* [doc] fix doc typo (): fix annotation display, fix llama2 doc
* [hotfix] add pp sanity check and fix mbs arg (): fix misleading mbs arg, add pp sanity check, fix 1f1b sanity check
* [workflow] fixed incomplete bash command ()
* [workflow] fixed oom tests ()
* [ci] fix test_hybrid_parallel_plugin_checkpoint_io.py ()
* [shardformer] hybridparallelplugin support gradients accumulation ()
* [hotfix] fix ShardFormer test execution path when using sequence parallelism ()
* fix auto loading gpt2 tokenizer ()
* [doc] add llama2-13B display ()
* fix llama pretrain ()
* Update shardformer.py

---------

Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: Michelle <97082656+MichelleMa8@users.noreply.github.com>
Co-authored-by: Desperado-Jia <502205863@qq.com>
2024-03-04 16:18:13 +08:00


from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def low_precision_init(target_dtype: torch.dtype = torch.float16):
    """Context manager that temporarily sets the default dtype, so modules
    constructed inside the block are initialized directly in low precision."""
    dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(target_dtype)
        yield
    finally:
        # Restore the original default dtype even if initialization fails.
        torch.set_default_dtype(dtype)


def get_model_numel(model: nn.Module) -> int:
    """Return the total number of parameters in the model."""
    return sum(p.numel() for p in model.parameters())


def format_numel_str(numel: int) -> str:
    """Format a parameter count as a human-readable string using binary
    units, e.g. 8393728 -> "8.00 M"."""
    B = 1024**3
    M = 1024**2
    K = 1024
    if numel >= B:
        return f"{numel / B:.2f} B"
    elif numel >= M:
        return f"{numel / M:.2f} M"
    elif numel >= K:
        return f"{numel / K:.2f} K"
    else:
        return f"{numel}"
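
For reference, a minimal usage sketch (the import path, layer sizes, and printed values are illustrative assumptions, not part of the original file): build a model under low_precision_init so its weights are created as float16 from the start, then report its size.

import torch
import torch.nn as nn

# Hypothetical import path, assuming this file is on sys.path.
from model_utils import low_precision_init, get_model_numel, format_numel_str

# Illustrative stand-in model; any nn.Module works the same way.
with low_precision_init(torch.float16):
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

print(model[0].weight.dtype)                     # torch.float16
print(format_numel_str(get_model_numel(model)))  # "8.00 M"

Because the dtype swap happens at construction time, the parameters are never allocated in float32 and then cast down, which keeps peak memory lower when materializing large models for benchmarks like the gpt2 example this commit adds.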