mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[hotfix] Suport extra_kwargs in ShardConfig (#5031)
* [refactor]: replace inference args with extra_kwargs in ShardConfig * modify shardconfig * polish code * fix policy bug in llama * fix bug in auto policy * remove setattr in ShardConfig
This commit is contained in:
@@ -67,7 +67,9 @@ class Worker:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
|
||||
)
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True)
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if world_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||
)
|
||||
self.infer_engine = TPInferEngine(
|
||||
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
|
||||
)
|
||||
|
@@ -45,8 +45,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.inference_gptq:
|
||||
if self.shard_config.extra_kwargs.get("inference_gptq", False):
|
||||
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
|
||||
|
||||
decoder_attribute_replacement = {
|
||||
|
@@ -44,7 +44,7 @@ class TPInferEngine:
|
||||
>>> # define model and shard config for your inference
|
||||
>>> model = ...
|
||||
>>> generate_kwargs = ...
|
||||
>>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
|
||||
>>> shard_config = ShardConfig(enable_tensor_parallelism=True, extra_kwargs={"inference_only": True})
|
||||
>>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
|
||||
"""
|
||||
@@ -181,7 +181,7 @@ class TPInferEngine:
|
||||
In further generation, use the sharded model instead of original model.
|
||||
"""
|
||||
# NOTE we will change to use an inference config later with additional attrs we want
|
||||
assert self.shard_config.inference_only is True
|
||||
assert self.shard_config.extra_kwargs["inference_only"] is True
|
||||
shardformer = ShardFormer(shard_config=self.shard_config)
|
||||
self._prepare_with_shard_config(shard_config=self.shard_config)
|
||||
self._shard_model_by(shardformer, model)
|
||||
@@ -203,10 +203,10 @@ class TPInferEngine:
|
||||
enable_all_optimization=False,
|
||||
enable_flash_attention=False,
|
||||
enable_jit_fused=False,
|
||||
inference_only=True,
|
||||
extra_kwargs={"inference_only": True},
|
||||
)
|
||||
else:
|
||||
shard_config.inference_only = True
|
||||
shard_config.extra_kwargs = {"inference_only": True}
|
||||
shard_config.pipeline_stage_manager = None
|
||||
if shard_config.enable_tensor_parallelism:
|
||||
self.tp_size = shard_config.tensor_parallel_size
|
||||
@@ -221,13 +221,11 @@ class TPInferEngine:
|
||||
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
|
||||
model_name = model.__class__.__name__
|
||||
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
|
||||
|
||||
model = model.model if self.shard_config.inference_gptq else model
|
||||
if self.shard_config.extra_kwargs.get("inference_gptq", False):
|
||||
model = model.model
|
||||
policy = get_autopolicy(model, shard_config=self.shard_config)
|
||||
|
||||
self.model, _ = shardformer.optimize(model, policy)
|
||||
|
||||
if self.shard_config.inference_gptq:
|
||||
if self.shard_config.extra_kwargs.get("inference_gptq", False):
|
||||
self._post_init_gptq_buffer(self.model)
|
||||
|
||||
self.model = self.model.cuda()
|
||||
|
@@ -4,7 +4,6 @@ import torch
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
||||
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
|
||||
|
||||
@@ -38,35 +37,39 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
|
||||
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
|
||||
|
||||
policy = super().module_policy()
|
||||
if self.shard_config.inference_gptq:
|
||||
|
||||
if self.shard_config.extra_kwargs.get("inference_gptq", False):
|
||||
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
|
||||
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
|
||||
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=ColCaiQuantLinear,
|
||||
kwargs={'split_num': 3}),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense",
|
||||
target_module=RowCaiQuantLinear,
|
||||
kwargs={'split_num': 1}),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.attention_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_h_to_4h",
|
||||
target_module=ColCaiQuantLinear,
|
||||
kwargs={'split_num': 1}),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_4h_to_h",
|
||||
target_module=RowCaiQuantLinear,
|
||||
kwargs={'split_num': 1}),
|
||||
])
|
||||
|
||||
policy[BloomBlock] = ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
"self_attention.hidden_size": self.model.config.hidden_size
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
"self_attention.split_size": self.model.config.hidden_size
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=ColCaiQuantLinear,
|
||||
kwargs={"split_num": 3},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.attention_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1}
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
|
||||
),
|
||||
],
|
||||
)
|
||||
# NOTE set inference mode to shard config
|
||||
self.shard_config._infer()
|
||||
|
||||
|
@@ -13,6 +13,7 @@ from ..modeling.llama import LlamaInferenceForwards
|
||||
|
||||
try:
|
||||
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
|
||||
|
||||
HAS_TRITON_RMSNORM = True
|
||||
except:
|
||||
print("you should install triton from https://github.com/openai/triton")
|
||||
@@ -21,6 +22,7 @@ except:
|
||||
|
||||
def get_triton_rmsnorm_forward():
|
||||
if HAS_TRITON_RMSNORM:
|
||||
|
||||
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
|
||||
return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
|
||||
|
||||
@@ -36,7 +38,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.inference_gptq:
|
||||
if self.shard_config.extra_kwargs.get("inference_gptq", False):
|
||||
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
|
||||
|
||||
decoder_attribute_replacement = {
|
||||
|
Reference in New Issue
Block a user