[Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controlling in inference
Author: Zhongkai Zhao (committed by GitHub)
Date:   2023-11-14 17:26:59 +08:00
Commit: 361cf63cb0
Parent: c6295c3381
10 changed files with 27 additions and 18 deletions

@@ -142,7 +142,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
       options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
-    timeout-minutes: 60
+    timeout-minutes: 100
     defaults:
       run:
         shell: bash

@@ -51,7 +51,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
                 "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
                 // self.shard_config.tensor_parallel_size,
             }
-            if self.shard_config.quant == "gptq":
+            if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
                 from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

                 policy[LlamaDecoderLayer] = ModulePolicyDescription(
@@ -95,7 +95,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
                     ],
                 )
-            elif self.shard_config.quant == "smoothquant":
+            elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
                 from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
                 from colossalai.inference.quant.smoothquant.models.parallel_linear import (
                     ColW8A8BFP32OFP32Linear,
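
With this change the Llama inference policy stops reading a dedicated `shard_config.quant` attribute and instead looks the quant type up in `extra_kwargs`. A minimal sketch of that branching logic is below; the helper function and the returned strings are illustrative stand-ins, not the actual ColossalAI classes:

```python
# Hedged sketch of quant-type dispatch via extra_kwargs (illustrative only).
def select_linear_impl(extra_kwargs: dict) -> str:
    """Return the linear-layer flavour implied by extra_kwargs["quant"]."""
    quant = extra_kwargs.get("quant", None)
    if quant == "gptq":
        return "ColCaiQuantLinear / RowCaiQuantLinear"  # GPTQ int4 path
    elif quant == "smoothquant":
        return "ColW8A8BFP32OFP32Linear"                # SmoothQuant int8 path
    return "default tensor-parallel Linear"             # no quantization requested


print(select_linear_impl({"quant": "gptq"}))  # GPTQ path
print(select_linear_impl({}))                 # default path
```

An unknown or absent `quant` value simply falls through to the default path, mirroring how the policy's `if`/`elif` chain behaves.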

@@ -81,7 +81,7 @@ Following are the description `ShardConfig`'s arguments:
 - `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
-- `inference_only`: Whether only doing forward passing. Defaults to False.
+- `extra_kwargs`: A dict to store extra kwargs for ShardFomer.

 ### Write your own policy
@@ -185,8 +185,8 @@ class ShardConfig:
     # Some possible future config fields
     tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
-    inference_only: bool # only inject inference-suitable sharding policy
     use_flash_attention: bool # whether to use flash attention to speed up attention
+    extra_kwargs: Dict[str, Any] # extra kwargs for the shardformer
 ```
 ### Policy
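
Per the updated documentation, `extra_kwargs` is an ordinary dict. The keys this PR actually reads are shown below; this is a hedged example, not a fixed schema, and unrecognized keys are simply ignored:

```python
extra_kwargs = {
    "inference_only": True,   # route policy search to the inference policy table
    "quant": "smoothquant",   # or "gptq": selects quantized linear layers in the policy
}
```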

@@ -209,7 +209,8 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    if shard_config.inference_only:
+    inference_only = shard_config.extra_kwargs.get("inference_only", None)
+    if inference_only:
         policy_location = _INFER_POLICY_LIST.get(full_name, None)
     else:
         policy_location = _POLICY_LIST.get(full_name, None)
@@ -219,5 +220,5 @@
             f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
         )
     else:
-        policy = import_policy(policy_location, shard_config.inference_only)
+        policy = import_policy(policy_location, inference_only)
     return policy()
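
The dispatch in `get_autopolicy` now pulls `inference_only` out of `extra_kwargs` with a `get(..., None)` default, so configs that never set the key behave like the old `inference_only=False`. A self-contained sketch with toy tables standing in for `_POLICY_LIST` / `_INFER_POLICY_LIST`:

```python
# Hedged sketch of the policy lookup; the table contents are toy values.
from typing import Any, Dict

_POLICY_LIST = {"transformers.models.llama.modeling_llama.LlamaForCausalLM": "LlamaForCausalLMPolicy"}
_INFER_POLICY_LIST = {"transformers.models.llama.modeling_llama.LlamaForCausalLM": "LlamaModelInferPolicy"}


def get_autopolicy_sketch(full_name: str, extra_kwargs: Dict[str, Any]) -> str:
    inference_only = extra_kwargs.get("inference_only", None)  # a missing key is falsy
    table = _INFER_POLICY_LIST if inference_only else _POLICY_LIST
    policy = table.get(full_name, None)
    if policy is None:
        raise NotImplementedError(f"Auto policy for {full_name} is not implemented")
    return policy


name = "transformers.models.llama.modeling_llama.LlamaForCausalLM"
print(get_autopolicy_sketch(name, {"inference_only": True}))  # LlamaModelInferPolicy
print(get_autopolicy_sketch(name, {}))                        # LlamaForCausalLMPolicy
```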

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional from typing import Any, Dict, Optional
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
@ -33,11 +33,9 @@ class ShardConfig:
enable_flash_attention: bool = False enable_flash_attention: bool = False
enable_jit_fused: bool = False enable_jit_fused: bool = False
enable_all_optimization: bool = False enable_all_optimization: bool = False
inference_only: bool = False
inference_gptq: bool = False
enable_sequence_parallelism: bool = False enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False enable_sequence_overlap: bool = False
quant: str = None extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# pipeline_parallel_size: int # pipeline_parallel_size: int
# data_parallel_size: int # data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
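
The new field uses `field(default_factory=dict)` because dataclasses reject mutable defaults like `= {}`, and a shared default dict would leak keys across `ShardConfig` instances. A reduced stand-in illustrating that behaviour (only a subset of the real fields):

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ShardConfigSketch:  # reduced stand-in, not the full ShardConfig
    enable_tensor_parallelism: bool = True
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)


a = ShardConfigSketch()
a.extra_kwargs["quant"] = "gptq"      # mutate one instance's dict
b = ShardConfigSketch()
assert "quant" not in b.extra_kwargs  # each instance gets its own, independent dict
```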

@@ -28,7 +28,9 @@ def bench_bloom(args):
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

     # prepare data for generation
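
Caller-side, the migration is mechanical: drop the removed keyword arguments and pack the inference flag into `extra_kwargs`. A hedged sketch of the benchmark pattern follows; the import paths and the `TPInferEngine` signature are taken from this PR's scripts and may differ in other ColossalAI versions, and the model choice is purely illustrative:

```python
import torch
from transformers import BloomForCausalLM

from colossalai.inference import TPInferEngine  # import path assumed
from colossalai.shardformer import ShardConfig  # import path assumed


def build_engine(model_name: str, tp_size: int,
                 max_batch_size: int = 8, max_input_len: int = 1024, max_output_len: int = 128):
    model = BloomForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    # "inference_only" now travels through extra_kwargs instead of a ShardConfig field.
    shard_config = ShardConfig(
        enable_tensor_parallelism=tp_size > 1,
        extra_kwargs={"inference_only": True},
    )
    return TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
```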

@@ -30,7 +30,9 @@ def run_chatglm2_test(args):
     model = model.half()
     model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)

@@ -30,7 +30,9 @@ def run_llama_test(args):
     model = model.half()
     model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)

@@ -34,7 +34,9 @@ def bench_bloom(args):
     model = model.half()
     model_config = model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
@@ -46,7 +48,8 @@ def bench_bloom(args):
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "quant": "gptq"},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
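
For the GPTQ scripts, the two removed booleans collapse into `extra_kwargs` entries. A hypothetical migration helper (not part of this PR) that maps the old flags onto the new dict:

```python
from typing import Any, Dict, Optional


def to_extra_kwargs(inference_only: bool = False,
                    inference_gptq: bool = False,
                    quant: Optional[str] = None) -> Dict[str, Any]:
    """Translate the removed ShardConfig flags into the new extra_kwargs dict."""
    extra: Dict[str, Any] = {}
    if inference_only:
        extra["inference_only"] = True
    if inference_gptq:
        extra["quant"] = "gptq"  # old boolean maps onto the new string key
    elif quant is not None:
        extra["quant"] = quant   # "gptq" or "smoothquant"
    return extra


print(to_extra_kwargs(inference_only=True, inference_gptq=True))
# {'inference_only': True, 'quant': 'gptq'}
```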

@@ -33,7 +33,8 @@ def run_llama_test(args):
     model_config = model.config
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "quant": "gptq"},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)