[hotfix] Suport extra_kwargs in ShardConfig (#5031)

* [refactor]: replace inference args with extra_kwargs in ShardConfig

* modify shardconfig

* polish code

* fix policy bug in llama

* fix bug in auto policy

* remove setattr in ShardConfig
This commit is contained in:
Zhongkai Zhao
2023-11-10 10:49:50 +08:00
committed by GitHub
parent 576a2f7b10
commit 70885d707d
23 changed files with 98 additions and 77 deletions

View File

@@ -67,7 +67,9 @@ class Worker:
self.model = AutoModelForCausalLM.from_pretrained(
self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
)
shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if world_size > 1 else False, extra_kwargs={"inference_only": True}
)
self.infer_engine = TPInferEngine(
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
)

View File

@@ -45,8 +45,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.inference_gptq:
if self.shard_config.extra_kwargs.get("inference_gptq", False):
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
decoder_attribute_replacement = {

View File

@@ -44,7 +44,7 @@ class TPInferEngine:
>>> # define model and shard config for your inference
>>> model = ...
>>> generate_kwargs = ...
>>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
>>> shard_config = ShardConfig(enable_tensor_parallelism=True, extra_kwargs={"inference_only": True})
>>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
"""
@@ -181,7 +181,7 @@ class TPInferEngine:
In further generation, use the sharded model instead of original model.
"""
# NOTE we will change to use an inference config later with additional attrs we want
assert self.shard_config.inference_only is True
assert self.shard_config.extra_kwargs["inference_only"] is True
shardformer = ShardFormer(shard_config=self.shard_config)
self._prepare_with_shard_config(shard_config=self.shard_config)
self._shard_model_by(shardformer, model)
@@ -203,10 +203,10 @@ class TPInferEngine:
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
inference_only=True,
extra_kwargs={"inference_only": True},
)
else:
shard_config.inference_only = True
shard_config.extra_kwargs = {"inference_only": True}
shard_config.pipeline_stage_manager = None
if shard_config.enable_tensor_parallelism:
self.tp_size = shard_config.tensor_parallel_size
@@ -221,13 +221,11 @@ class TPInferEngine:
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
model = model.model if self.shard_config.inference_gptq else model
if self.shard_config.extra_kwargs.get("inference_gptq", False):
model = model.model
policy = get_autopolicy(model, shard_config=self.shard_config)
self.model, _ = shardformer.optimize(model, policy)
if self.shard_config.inference_gptq:
if self.shard_config.extra_kwargs.get("inference_gptq", False):
self._post_init_gptq_buffer(self.model)
self.model = self.model.cuda()

View File

@@ -4,7 +4,6 @@ import torch
from torch.nn import LayerNorm
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
@@ -38,35 +37,39 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
policy = super().module_policy()
if self.shard_config.inference_gptq:
if self.shard_config.extra_kwargs.get("inference_gptq", False):
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 3}),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=RowCaiQuantLinear,
kwargs={'split_num': 1}),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1}),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=RowCaiQuantLinear,
kwargs={'split_num': 1}),
])
policy[BloomBlock] = ModulePolicyDescription(
attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 3},
),
SubModuleReplacementDescription(
suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1}
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
),
],
)
# NOTE set inference mode to shard config
self.shard_config._infer()

View File

@@ -13,6 +13,7 @@ from ..modeling.llama import LlamaInferenceForwards
try:
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
@@ -21,6 +22,7 @@ except:
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
@@ -36,7 +38,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.inference_gptq:
if self.shard_config.extra_kwargs.get("inference_gptq", False):
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
decoder_attribute_replacement = {

View File

@@ -81,8 +81,6 @@ Following are the description `ShardConfig`'s arguments:
- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
- `inference_only`: Whether only doing forward passing. Defaults to False.
### Write your own policy
If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design).
@@ -185,7 +183,6 @@ class ShardConfig:
# Some possible future config fields
tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
inference_only: bool # only inject inference-suitable sharding policy
use_flash_attention: bool # whether to use flash attention to speed up attention
```

View File

@@ -209,7 +209,8 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
:class:`Policy`: The auto policy for the model
"""
full_name = _fullname(model)
if shard_config.inference_only:
inference_only = shard_config.extra_kwargs.get("inference_only", False)
if inference_only:
policy_location = _INFER_POLICY_LIST.get(full_name, None)
else:
policy_location = _POLICY_LIST.get(full_name, None)
@@ -219,5 +220,5 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location, shard_config.inference_only)
policy = import_policy(policy_location, inference_only)
return policy()

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from dataclasses import dataclass, field
from typing import Dict, Optional
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -24,7 +24,6 @@ class ShardConfig:
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalizaion', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
inference_only (bool): Whether only doing forward passing. Defaults to False.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -33,10 +32,9 @@ class ShardConfig:
enable_flash_attention: bool = False
enable_jit_fused: bool = False
enable_all_optimization: bool = False
inference_only: bool = False
inference_gptq: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
extra_kwargs: Dict[str, bool] = field(default_factory=dict)
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@@ -77,4 +75,3 @@ class ShardConfig:
Set default params for inference.
"""
# assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
pass