mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-28 16:28:10 +00:00

[feat] support zbv in mixtral benchmark;

This commit is contained in:
parent cc500b3e25
commit 3f5bec8dc4
@@ -11,6 +11,7 @@ from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
 from performance_evaluator import PerformanceEvaluator, get_profile_context
 from tqdm import tqdm
+from transformers import AutoConfig
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 
 import colossalai
@@ -20,6 +21,7 @@ from colossalai.booster.plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.pipeline.schedule.v_schedule import PipelineGraph
 from colossalai.shardformer import PipelineGradientCheckpointConfig
 
 warnings.filterwarnings("ignore")
@@ -85,7 +87,7 @@ def main():
     parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
 
-    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
+    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
     parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
     parser.add_argument("--profile", action="store_true", help="Profile the code")
     parser.add_argument(
@@ -120,7 +122,7 @@ def main():
                 num_ckpt_layers_per_stage=[19, 19, 19, 13],
             ),
             "num_layers_per_stage": [19, 20, 20, 21],
-            "pp_style": "interleaved",
+            # "pp_style": "interleaved",
         }
         if args.custom_ckpt
         else {}
@@ -129,7 +131,29 @@ def main():
     # ==============================
     # Initialize Booster
     # ==============================
+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
+
     if args.plugin == "3d":
+        if args.pp_style == "zbv":
+            mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
+            mem_w = -32 * config.hidden_size
+            mem_b = -mem_w - mem_f
+            scheduler_nodes = PipelineGraph(
+                n_stage=args.pp,
+                n_micro=args.batch_size // args.mbs,
+                f_cost=1000,
+                b_cost=1000,
+                w_cost=1000,
+                c_cost=1,
+                f_mem=mem_f,
+                b_mem=mem_b,
+                w_mem=mem_w,
+            ).get_v_schedule()
+        else:
+            scheduler_nodes = None
         plugin = MoeHybridParallelPlugin(
             ep_size=args.ep,
             tp_size=args.tp,
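When --pp_style is "zbv", the benchmark now precomputes a zero-bubble V schedule and hands the resulting nodes to MoeHybridParallelPlugin via the new scheduler_nodes argument (see the next hunk). The mem_f / mem_w / mem_b terms are per-microbatch memory bookkeeping: mem_f is what a forward pass holds, mem_w is what the weight-gradient pass releases, and mem_b is defined as -mem_w - mem_f so the three sum to zero once a microbatch has finished all passes. For example, with hidden_size=4096, 32 attention heads and max_length=4096: mem_f = 34*4096 + 5*32*4096 = 794624, mem_w = -131072, mem_b = -663552. Below is a minimal standalone sketch of the same construction; it uses only the PipelineGraph arguments and the get_v_schedule() call that appear in the diff, while the stage count, microbatch count and model dimensions are illustrative assumptions, not values from the commit.

# Minimal sketch of building a ZB-V schedule, mirroring the benchmark's zbv branch.
# Concrete numbers here are assumptions for illustration only.
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

hidden_size = 4096          # assumed model width
num_attention_heads = 32    # assumed number of attention heads
max_length = 4096           # assumed sequence length

# Same bookkeeping as the benchmark: forward holds mem_f, the weight-grad pass
# releases -mem_w, and backward releases the remainder, so mem_f + mem_b + mem_w == 0.
mem_f = 34 * hidden_size + 5 * num_attention_heads * max_length
mem_w = -32 * hidden_size
mem_b = -mem_w - mem_f

scheduler_nodes = PipelineGraph(
    n_stage=4,                      # pipeline parallel size (args.pp in the benchmark)
    n_micro=8,                      # number of microbatches (args.batch_size // args.mbs)
    f_cost=1000,                    # relative forward cost
    b_cost=1000,                    # relative backward (activation-grad) cost
    w_cost=1000,                    # relative weight-grad cost
    c_cost=1,                       # relative communication cost
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
).get_v_schedule()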
@@ -148,6 +172,7 @@ def main():
             overlap_allgather=args.overlap_allgather,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
+            scheduler_nodes=scheduler_nodes,
             **hybrid_kwargs,
         )
     else:
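For reference, a hypothetical invocation exercising the new path might look like the line below. Only --plugin value "3d", --pp_style "zbv", --zero and --n_chunks are visible in the hunks above; the remaining flag spellings (--pp, --ep, --tp, --batch_size, --mbs, --config) and the script name are assumptions inferred from the args.* attributes used in the diff and may differ from the actual parser definitions. ZB-V typically runs two model chunks per stage, hence the assumed --n_chunks 2.

torchrun --nproc_per_node 8 benchmark.py --config <model-config> --plugin 3d --pp_style zbv --pp 4 --ep 2 --tp 1 --batch_size 32 --mbs 4 --n_chunks 2

Since n_micro passed to PipelineGraph is batch_size // mbs, the global batch size should remain divisible by the microbatch size (32 // 4 = 8 microbatches in this example).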