Fixed a bug in the inference frame

This commit is contained in:
yuehuayingxueluo
2023-12-26 21:34:27 +08:00
committed by FrankLeeeee
parent 86853a37d5
commit 62fd08ee44
8 changed files with 261 additions and 90 deletions

View File

@@ -70,7 +70,10 @@ def llama_model_forward(
seq_length = input_ids.shape[1]
device = input_ids.device
past_key_values_length = len(block_tables.shape[1])
if batch.is_prompts:
past_key_values_length = 0
else:
past_key_values_length = sequence_lengths[0].item() - 1
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
@@ -163,26 +166,17 @@ def llama_attn_forward(
key_states = key_states.view(-1, self.num_heads, self.head_dim)
value_states = value_states.view(-1, self.num_heads, self.head_dim)
block_size = k_cache.shape[-1]
k_cache.shape[-1]
memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size)
# memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size, sequence_lengths)
if is_prompts:
attn_output = context_attention_unpadded(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
)
else:
attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
decoding_attention(
query_states,
k_cache,
v_cache,
block_tables,
sequence_lengths,
attn_output,
block_tables.shape[1],
block_size,
)
# if is_prompts:
# attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
# else:
# attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
# decoding_attention(query_states, k_cache, v_cache, block_tables, sequence_lengths, attn_output, block_tables.shape[1], block_size)
attn_output = query_states
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -190,19 +184,3 @@ def llama_attn_forward(
attn_output = self.o_proj(attn_output)
return attn_output
def memcpy_to_block(key, value, k_cache, v_cache, block_tables, block_size):
block_table_list = block_tables.tolist()
batch_size, seq_len, num_heads, head_dim = key
reshape_key = key.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
reshape_value = value.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
if seq_len == 1:
for i in range(batch_size):
k_cache[block_table_list[i][-1], :] = reshape_key[i]
v_cache[block_table_list[i][-1], :] = reshape_value[i]
else:
for i in range(batch_size):
k_cache[block_table_list[i], :] = reshape_key[i]
v_cache[block_table_list[i], :] = reshape_value[i]

View File

@@ -1,7 +1,165 @@
from functools import partial
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaFlashAttention2,
LlamaForCausalLM,
LlamaModel,
LlamaSdpaAttention,
)
from colossalai.inference.modeling.models.llama import (
llama_attn_forward,
llama_causal_lm_forward,
llama_decoder_layer_forward,
llama_model_forward,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
# The code here just for test and will be modified later.
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
}
if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=RowCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=RowCaiQuantLinear,
kwargs={"split_num": 1},
),
],
)
elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
from colossalai.inference.quant.smoothquant.models.parallel_linear import (
ColW8A8BFP32OFP32Linear,
RowW8A8B8O8Linear,
RowW8A8BFP32O32LinearSiLU,
RowW8A8BFP32OFP32Linear,
)
policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=RowW8A8BFP32O32LinearSiLU,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=RowW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
],
)
self.shard_config._infer()
infer_forward = llama_causal_lm_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaForCausalLM
)
infer_forward = llama_model_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
infer_forward = llama_decoder_layer_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaAttention
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
)
return policy