mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[moe] init mixtral impl
This commit is contained in:
138
applications/ColossalMoE/infer.py
Normal file
138
applications/ColossalMoE/infer.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
|
||||
from colossal_moe.models.mixtral_layer import replace_moe_layer
|
||||
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
|
||||
from colossal_moe.utils import load_model
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.moe import MOE_MANAGER
|
||||
from colossalai.moe.utils import skip_init
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def parse_args():
|
||||
# basic settings
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="mistralai/Mixtral-8x7B-v0.1",
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="hybrid",
|
||||
choices=["ep"],
|
||||
help="Parallel methos.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default="./outputs",
|
||||
help="The path of your saved model after finetuning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
default="bf16",
|
||||
choices=["fp32", "bf16", "fp16"],
|
||||
help="The mixed precision training.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
||||
|
||||
# kernel
|
||||
parser.add_argument(
|
||||
"--use_kernel",
|
||||
action="store_true",
|
||||
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_layernorm_kernel",
|
||||
action="store_true",
|
||||
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Launch ColossalAI
|
||||
colossalai.launch_from_torch(config={}, seed=args.seed)
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# Set plugin
|
||||
booster_kwargs = {}
|
||||
hybrid_dict = {
|
||||
"tp_size": 1,
|
||||
"custom_policy": MixtralForCausalLMPolicy(),
|
||||
"enable_fused_normalization": args.use_layernorm_kernel,
|
||||
"enable_jit_fused": args.use_kernel,
|
||||
"precision": args.precision,
|
||||
"checkpoint_io": MixtralMoECheckpointIO,
|
||||
"zero_stage": 1,
|
||||
}
|
||||
mgr_dict = {}
|
||||
if args.plugin == "ep":
|
||||
dp_size = dist.get_world_size()
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=1,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
max_ep_size=dp_size,
|
||||
**mgr_dict,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin {args.plugin}")
|
||||
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
|
||||
|
||||
# Build mixtral model
|
||||
config = MixtralConfig.from_pretrained(args.model_name)
|
||||
config.num_local_experts = 1 # dont change this. it will not affect model
|
||||
with skip_init():
|
||||
model = MixtralForCausalLM(config)
|
||||
model.num_experts = 8
|
||||
model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
|
||||
model = model.to(get_current_device())
|
||||
coordinator.print_on_master(f"Finish init model with config:\n{config}")
|
||||
|
||||
# Replace moe
|
||||
with skip_init():
|
||||
replace_moe_layer(model)
|
||||
model.eval()
|
||||
coordinator.print_on_master(f"Finish replace moe module")
|
||||
|
||||
# Prepare tokenizer and dataloader
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
||||
|
||||
# Set booster
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
model, _, _, _, _ = booster.boost(model=model)
|
||||
coordinator.print_on_master(f"Finish init booster")
|
||||
|
||||
# load ckpt
|
||||
load_model(args.model_name, model, booster)
|
||||
coordinator.print_on_master(f"Finish load ckpt")
|
||||
|
||||
text = ["Hello my name is", "1+1=?"]
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
|
||||
outputs = model.module.generate(**inputs, max_new_tokens=20)
|
||||
outputs = tokenizer.batch_decode(outputs)[0]
|
||||
print(outputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user