mirror of https://github.com/hpcaitech/ColossalAI.git
[hotfix] Support extra_kwargs in ShardConfig (#5031)
* [refactor]: replace inference args with extra_kwargs in ShardConfig
* modify shardconfig
* polish code
* fix policy bug in llama
* fix bug in auto policy
* remove setattr in ShardConfig
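The change itself is mechanical: inference-specific switches such as inference_only and inference_gptq no longer ride on ShardConfig as top-level keyword arguments, but are passed through a single extra_kwargs dict. As a rough sketch of the new shape (the field and its default below are inferred from the call sites in this diff; the rest of ShardConfig is omitted and illustrative only):

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class ShardConfig:
    # Sketch only: the real ShardConfig has more fields; extra_kwargs is the
    # one introduced by this commit and defaults to an empty dict.
    enable_tensor_parallelism: bool = True
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)

# Old call sites: ShardConfig(..., inference_only=True)
# New call sites: route the same flags through extra_kwargs.
config = ShardConfig(
    enable_tensor_parallelism=True,
    extra_kwargs={"inference_only": True, "inference_gptq": True},
)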
@@ -28,7 +28,9 @@ def bench_bloom(args):
 
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
     # prepare data for generation
@@ -30,7 +30,9 @@ def run_chatglm2_test(args):
     model = model.half()
     model.config
 
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)
@@ -30,7 +30,9 @@ def run_llama_test(args):
     model = model.half()
     model.config
 
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)
@@ -34,7 +34,9 @@ def bench_bloom(args):
     model = model.half()
 
     model_config = model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
 
@@ -46,7 +48,8 @@ def bench_bloom(args):
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "inference_gptq": True},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
@@ -33,7 +33,8 @@ def run_llama_test(args):
 
     model_config = model.config
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "inference_gptq": True},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
@@ -68,7 +68,9 @@ class Worker:
             self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
         )
 
-        shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True)
+        shard_config = ShardConfig(
+            enable_tensor_parallelism=True if world_size > 1 else False, extra_kwargs={"inference_only": True}
+        )
         self.infer_engine = TPInferEngine(
             self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
         )
@@ -100,7 +100,9 @@ class ColossalInferenceHandler(BaseHandler, ABC):
 
         colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
        logger.info("Initializing TPInferEngine ...")
-        shard_config = ShardConfig(enable_tensor_parallelism=True if self.tp_size > 1 else False, inference_only=True)
+        shard_config = ShardConfig(
+            enable_tensor_parallelism=True if self.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+        )
         self.infer_engine = TPInferEngine(
             self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
         )
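The hunks above only show the construction side. Code that previously read these switches as ShardConfig attributes (e.g. shard_config.inference_only) now has to look them up in the dict, with a default so that callers that never set the keys keep working. A hypothetical consumer-side sketch follows; the class and method names are placeholders, not the actual ColossalAI policy API:

class InferenceAwarePolicySketch:
    """Placeholder policy illustrating the dict lookup pattern, not a real ColossalAI class."""

    def __init__(self, shard_config: "ShardConfig") -> None:
        self.shard_config = shard_config

    def inference_only(self) -> bool:
        # Missing keys fall back to False, so non-inference configs are unaffected.
        return bool(self.shard_config.extra_kwargs.get("inference_only", False))

    def uses_gptq(self) -> bool:
        return bool(self.shard_config.extra_kwargs.get("inference_gptq", False))

Using .get(key, False) rather than attribute access is what lets ShardConfig drop the per-feature keyword arguments (and the setattr logic removed by this commit) without breaking existing callers.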