diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index e2114d43b..a926bd18a 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -142,7 +142,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
       options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
-    timeout-minutes: 60
+    timeout-minutes: 100
     defaults:
       run:
         shell: bash
diff --git a/colossalai/inference/hybridengine/polices/llama.py b/colossalai/inference/hybridengine/polices/llama.py
index 3cdfc0173..11517d7e8 100644
--- a/colossalai/inference/hybridengine/polices/llama.py
+++ b/colossalai/inference/hybridengine/polices/llama.py
@@ -51,7 +51,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
             "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
             // self.shard_config.tensor_parallel_size,
         }
-        if self.shard_config.quant == "gptq":
+        if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
             from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
 
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
@@ -95,7 +95,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
                 ],
             )
 
-        elif self.shard_config.quant == "smoothquant":
+        elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
             from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
             from colossalai.inference.quant.smoothquant.models.parallel_linear import (
                 ColW8A8BFP32OFP32Linear,
diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index cabd10bba..cf06eecd3 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -81,7 +81,7 @@ Following are the description `ShardConfig`'s arguments:
 
 - `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
 
-- `inference_only`: Whether only doing forward passing. Defaults to False.
+- `extra_kwargs`: A dict to store extra kwargs for ShardFomer.
 
 ### Write your own policy
 
@@ -185,8 +185,8 @@ class ShardConfig:
 
     # Some possible future config fields
     tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
-    inference_only: bool # only inject inference-suitable sharding policy
     use_flash_attention: bool # whether to use flash attention to speed up attention
+    extra_kwargs: Dict[str, Any] # extra kwargs for the shardformer
 ```
 
 ### Policy
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index 3014f1cf3..c46934fb0 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -209,7 +209,8 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    if shard_config.inference_only:
+    inference_only = shard_config.extra_kwargs.get("inference_only", None)
+    if inference_only:
         policy_location = _INFER_POLICY_LIST.get(full_name, None)
     else:
         policy_location = _POLICY_LIST.get(full_name, None)
@@ -219,5 +220,5 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
             f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
         )
     else:
-        policy = import_policy(policy_location, shard_config.inference_only)
+        policy = import_policy(policy_location, inference_only)
     return policy()
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index c7d63c234..64e5489e7 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -1,5 +1,5 @@
-from dataclasses import dataclass
-from typing import Optional
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
 
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -33,11 +33,9 @@ class ShardConfig:
     enable_flash_attention: bool = False
     enable_jit_fused: bool = False
     enable_all_optimization: bool = False
-    inference_only: bool = False
-    inference_gptq: bool = False
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
-    quant: str = None
+    extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # pipeline_parallel_size: int
     # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
diff --git a/examples/inference/bench_bloom.py b/examples/inference/bench_bloom.py
index 054641f6e..5c7af6ed5 100644
--- a/examples/inference/bench_bloom.py
+++ b/examples/inference/bench_bloom.py
@@ -28,7 +28,9 @@ def bench_bloom(args):
 
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
     # prepare data for generation
diff --git a/examples/inference/bench_chatglm2.py b/examples/inference/bench_chatglm2.py
index f3678d29f..3892d98ba 100644
--- a/examples/inference/bench_chatglm2.py
+++ b/examples/inference/bench_chatglm2.py
@@ -30,7 +30,9 @@ def run_chatglm2_test(args):
     model = model.half()
     model.config
 
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)
 
diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py
index 56bf062e2..4db32c71a 100644
--- a/examples/inference/bench_llama.py
+++ b/examples/inference/bench_llama.py
@@ -30,7 +30,9 @@ def run_llama_test(args):
     model = model.half()
     model.config
 
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)
 
diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py
index cfa317137..a6e07b98c 100644
--- a/examples/inference/gptq_bloom.py
+++ b/examples/inference/gptq_bloom.py
@@ -34,7 +34,9 @@ def bench_bloom(args):
     model = model.half()
     model_config = model.config
 
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
 
@@ -46,7 +48,8 @@ def bench_bloom(args):
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "quant": "gptq"},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py
index 35a6049ad..61da7ca24 100644
--- a/examples/inference/gptq_llama.py
+++ b/examples/inference/gptq_llama.py
@@ -33,7 +33,8 @@ def run_llama_test(args):
     model_config = model.config
 
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "quant": "gptq"},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
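Migration note (not part of the patch): the dedicated `inference_only`, `inference_gptq`, and `quant` flags on `ShardConfig` are replaced by a single `extra_kwargs` dict, which callers populate and policies read back with `.get(...)`. Below is a minimal before/after sketch for call sites outside this diff, assuming the `from colossalai.shardformer import ShardConfig` import path used by the touched examples; the engine wiring itself mirrors `TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)` above.

from colossalai.shardformer import ShardConfig

# Before this patch (flags removed from ShardConfig by the shard_config.py hunk):
#   shard_config = ShardConfig(inference_only=True, inference_gptq=True)

# After this patch: inference-related switches live in a single extra_kwargs dict.
# enable_tensor_parallelism is left at its default here; the examples above set it
# to `args.tp_size > 1` inside a launched distributed run.
shard_config = ShardConfig(extra_kwargs={"inference_only": True, "quant": "gptq"})

# Consumers read the switches defensively, defaulting to None when a key is absent
# (see auto_policy.py and the llama inference policy in this diff).
inference_only = shard_config.extra_kwargs.get("inference_only", None)
quant = shard_config.extra_kwargs.get("quant", None)  # "gptq", "smoothquant", or None
assert inference_only is True and quant == "gptq"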