[hotfix] Suport extra_kwargs in ShardConfig (#5031)

* [refactor]: replace inference args with extra_kwargs in ShardConfig

* modify shardconfig

* polish code

* fix policy bug in llama

* fix bug in auto policy

* remove setattr in ShardConfig
This commit is contained in:
Zhongkai Zhao
2023-11-10 10:49:50 +08:00
committed by GitHub
parent 576a2f7b10
commit 70885d707d
23 changed files with 98 additions and 77 deletions

View File

@@ -28,7 +28,9 @@ def bench_bloom(args):
# init TPInferEngine and shard the original model
# To benchmark torch original, comment out the line of optimizing model
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
# prepare data for generation

View File

@@ -30,7 +30,9 @@ def run_chatglm2_test(args):
model = model.half()
model.config
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
generate_kwargs = dict(max_new_tokens=1, do_sample=False)

View File

@@ -30,7 +30,9 @@ def run_llama_test(args):
model = model.half()
model.config
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
generate_kwargs = dict(max_new_tokens=1, do_sample=False)

View File

@@ -34,7 +34,9 @@ def bench_bloom(args):
model = model.half()
model_config = model.config
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
@@ -46,7 +48,8 @@ def bench_bloom(args):
# init TPInferEngine and shard the original model
# To benchmark torch original, comment out the line of optimizing model
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
enable_tensor_parallelism=True if args.tp_size > 1 else False,
extra_kwargs={"inference_only": True, "inference_gptq": True},
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

View File

@@ -33,7 +33,8 @@ def run_llama_test(args):
model_config = model.config
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
enable_tensor_parallelism=True if args.tp_size > 1 else False,
extra_kwargs={"inference_only": True, "inference_gptq": True},
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

View File

@@ -68,7 +68,9 @@ class Worker:
self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
)
shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if world_size > 1 else False, extra_kwargs={"inference_only": True}
)
self.infer_engine = TPInferEngine(
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
)

View File

@@ -100,7 +100,9 @@ class ColossalInferenceHandler(BaseHandler, ABC):
colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
logger.info("Initializing TPInferEngine ...")
shard_config = ShardConfig(enable_tensor_parallelism=True if self.tp_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if self.tp_size > 1 else False, extra_kwargs={"inference_only": True}
)
self.infer_engine = TPInferEngine(
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
)