Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 04:24:47 +00:00)
[Refactor] refactor policy search and quant type controlling in inference (#5035)
* [Refactor] refactor policy search and quant type controlling in inference
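Across the benchmark scripts touched below, the change is the same: the inference-specific flags that used to be passed directly to ShardConfig (inference_only, and inference_gptq for quantized runs) are folded into a single extra_kwargs dict, so the quant type is controlled by a key in that dict rather than by a dedicated config field. A minimal sketch of the new calling pattern follows; the import paths, tp_size handling, and sizing values are assumptions for illustration and do not appear in the diff itself.

# Sketch only: import paths and the sizing values below are assumptions;
# the diff shows only the ShardConfig/TPInferEngine call sites.
from colossalai.inference import TPInferEngine   # assumed import path
from colossalai.shardformer import ShardConfig   # assumed import path


def build_infer_engine(model, tp_size: int) -> TPInferEngine:
    # Inference options now live in extra_kwargs instead of dedicated ShardConfig fields.
    shard_config = ShardConfig(
        enable_tensor_parallelism=tp_size > 1,
        extra_kwargs={"inference_only": True},
    )
    # Positional arguments follow the benchmarks: max_batch_size, max_input_len, max_output_len
    # (the concrete values 4, 1024, 128 are illustrative, not taken from the diff).
    return TPInferEngine(model, shard_config, 4, 1024, 128)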
@@ -28,7 +28,9 @@ def bench_bloom(args):
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

     # prepare data for generation
@@ -30,7 +30,9 @@ def run_chatglm2_test(args):
     model = model.half()
     model.config

-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

     generate_kwargs = dict(max_new_tokens=1, do_sample=False)
@@ -30,7 +30,9 @@ def run_llama_test(args):
     model = model.half()
     model.config

-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

     generate_kwargs = dict(max_new_tokens=1, do_sample=False)
@@ -34,7 +34,9 @@ def bench_bloom(args):
     model = model.half()

     model_config = model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
@@ -46,7 +48,8 @@ def bench_bloom(args):
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "quant": "gptq"},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
@@ -33,7 +33,8 @@ def run_llama_test(args):

     model_config = model.config
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "quant": "gptq"},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
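For GPTQ-quantized checkpoints, the quant type is now selected through the same extra_kwargs dict (the two hunks above pass "quant": "gptq") instead of the old inference_gptq flag. A minimal sketch under the same assumptions as the earlier example:

# Sketch only: reuses the assumed import paths and illustrative sizes from the sketch above;
# "gptq" is the quant value shown in these hunks.
from colossalai.inference import TPInferEngine   # assumed import path
from colossalai.shardformer import ShardConfig   # assumed import path


def build_gptq_infer_engine(model, tp_size: int) -> TPInferEngine:
    shard_config = ShardConfig(
        enable_tensor_parallelism=tp_size > 1,
        # The quant type rides along in extra_kwargs instead of a dedicated inference_gptq field.
        extra_kwargs={"inference_only": True, "quant": "gptq"},
    )
    return TPInferEngine(model, shard_config, 4, 1024, 128)  # illustrative sizes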