[Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controlling in inference
Author: Zhongkai Zhao
Date: 2023-11-14 17:26:59 +08:00
Committed by: GitHub
Parent: c6295c3381
Commit: 361cf63cb0
10 changed files with 27 additions and 18 deletions
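
The refactor replaces the dedicated `inference_only` / `inference_gptq` fields on `ShardConfig` with a generic `extra_kwargs` dict, so inference switches and the quant type are passed as keys rather than as config attributes. Below is a minimal sketch of the new calling convention; only the `ShardConfig` and `TPInferEngine` arguments are taken from the diff, while the import paths, model name, and size limits are assumptions for illustration.

```python
from transformers import AutoModelForCausalLM  # assumed model source, not part of this commit

from colossalai.shardformer import ShardConfig   # import paths are assumptions
from colossalai.inference import TPInferEngine

# Assumed benchmark-style settings, mirroring the scripts touched by this commit.
tp_size = 2
max_batch_size, max_input_len, max_output_len = 4, 1024, 128

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m").half()

# Old style (removed): ShardConfig(..., inference_only=True, inference_gptq=True)
# New style: route inference flags through extra_kwargs.
shard_config = ShardConfig(
    enable_tensor_parallelism=tp_size > 1,
    extra_kwargs={"inference_only": True},
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
```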

@@ -28,7 +28,9 @@ def bench_bloom(args):
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     # prepare data for generation

@@ -30,7 +30,9 @@ def run_chatglm2_test(args):
     model = model.half()
     model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)

@@ -30,7 +30,9 @@ def run_llama_test(args):
     model = model.half()
     model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)

@@ -34,7 +34,9 @@ def bench_bloom(args):
     model = model.half()
     model_config = model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
@@ -46,7 +48,8 @@ def bench_bloom(args):
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "quant": "gptq"},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

@@ -33,7 +33,8 @@ def run_llama_test(args):
     model_config = model.config
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "quant": "gptq"},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
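
For the GPTQ benchmarks in the last two hunks, the only difference is the `extra_kwargs` payload: the quant type is now selected by the `"quant"` string key instead of a dedicated `inference_gptq` flag. A brief sketch of that variant, under the same assumed imports and settings as above:

```python
# GPTQ variant (sketch): quant type chosen via the "quant" key, per this commit's diff.
shard_config = ShardConfig(
    enable_tensor_parallelism=tp_size > 1,
    extra_kwargs={"inference_only": True, "quant": "gptq"},
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
```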