Mirror of https://github.com/hpcaitech/ColossalAI.git
[devops] remove post commit ci (#5566)
* [devops] remove post commit ci
* [misc] run pre-commit on all files
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -16,17 +16,15 @@ def inference(args):
     tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
     if args.model == "test":
         config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base")
-        set_openmoe_args(config,
-                         num_experts=config.num_experts,
-                         moe_layer_interval=config.moe_layer_interval,
-                         enable_kernel=True)
+        set_openmoe_args(
+            config, num_experts=config.num_experts, moe_layer_interval=config.moe_layer_interval, enable_kernel=True
+        )
         model = OpenMoeForCausalLM(config)
     else:
         config = LlamaConfig.from_pretrained(f"hpcai-tech/openmoe-{args.model}")
-        set_openmoe_args(config,
-                         num_experts=config.num_experts,
-                         moe_layer_interval=config.moe_layer_interval,
-                         enable_kernel=False)
+        set_openmoe_args(
+            config, num_experts=config.num_experts, moe_layer_interval=config.moe_layer_interval, enable_kernel=False
+        )
         model = OpenMoeForCausalLM.from_pretrained(f"hpcai-tech/openmoe-{args.model}", config=config)
     model = model.eval().bfloat16()
     model = model.to(torch.cuda.current_device())
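For orientation, a minimal sketch of how the model and tokenizer built in this hunk might be exercised further down the example's inference() function. This is not the example's actual prompt handling; it only assumes the standard transformers tokenizer/generate API on top of the objects constructed above (`tokenizer`, `model`), with the prompt text and generation settings chosen arbitrarily here.

    # Assumes `tokenizer` and `model` were built as in the hunk above,
    # with the model already in eval mode, bfloat16, and on the current CUDA device.
    inputs = tokenizer("The best way to learn is", return_tensors="pt").to(torch.cuda.current_device())
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))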