Pass inference model shard configs for module init

Signed-off-by: char-1ee <xingjianli59@gmail.com>
char-1ee
2024-06-07 08:28:19 +00:00
parent eec77e5702
commit 5f398fc000
11 changed files with 238 additions and 136 deletions

@@ -16,21 +16,13 @@ from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
)
from colossalai.inference.config import InputMetaData
from colossalai.inference.config import InputMetaData, ModelShardInferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend, AttentionMetaData
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
flash_decoding_attention,
get_xine_cache,
rms_layernorm,
rotary_embedding,
)
from colossalai.kernel.triton import get_xine_cache, rms_layernorm
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor
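The new `ModelShardInferenceConfig` import replaces the per-call flags that previously drove backend selection (see the removed `get_attention_backend(use_spec_dec=..., use_cuda_kernel=..., dtype=...)` call further down in this diff). Purely for orientation, a minimal sketch of what such a config plausibly bundles; the real class in `colossalai.inference.config` is authoritative and may carry more fields:

```python
# Illustrative stand-in only -- the real ModelShardInferenceConfig is defined in
# colossalai.inference.config and may differ. It bundles the flags that the old
# code threaded through every forward call.
from dataclasses import dataclass

import torch


@dataclass
class ModelShardInferenceConfigSketch:
    dtype: torch.dtype = torch.float16  # runtime dtype of this model shard
    use_cuda_kernel: bool = False       # prefer fused CUDA kernels over Triton paths
    use_spec_dec: bool = False          # speculative decoding (verifier) mode
```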
@@ -233,7 +225,6 @@ def llama_decoder_layer_forward(
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
cu_seqlens=cu_seqlens,
high_precision=high_precision,
)
@@ -397,6 +388,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
attn_vproj_w: torch.Tensor = None,
attn_oproj: ParallelModule = None,
process_group: ProcessGroup = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
num_heads: int = None,
hidden_size: int = None,
num_key_value_heads: int = None,
@@ -428,6 +420,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
self.rope_theta = config.rope_theta
self.is_causal = True
self.attention_backend = get_attention_backend(model_shard_infer_config)
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
if self.num_heads == self.num_key_value_heads:
qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
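With the config available at construction time, the attention and pre-attention backends are resolved once in `__init__` rather than on every forward pass. A rough sketch of how such a factory could dispatch, assuming it keys off the same flags the removed per-call selection used; the actual `get_attention_backend` in `colossalai.inference.modeling.backends.attention_backend` is the source of truth:

```python
# Hypothetical dispatcher for illustration; the real get_attention_backend()
# in colossalai.inference.modeling.backends.attention_backend is authoritative.
import torch


def get_attention_backend_sketch(config) -> str:
    """Pick a backend from the shard-level config instead of per-call flags."""
    cuda_friendly_dtype = config.dtype in (torch.float16, torch.bfloat16)
    if config.use_cuda_kernel and cuda_friendly_dtype and not config.use_spec_dec:
        return "cuda"   # fused CUDA kernel path
    return "triton"     # Triton kernel fallback
```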
@@ -457,6 +452,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
attn_vproj_w = module.v_proj.weight
assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor"
attn_oproj = module.o_proj
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
attn_layer = NopadLlamaAttention(
config=config,
@@ -466,6 +462,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
attn_vproj_w=attn_vproj_w,
attn_oproj=attn_oproj,
process_group=process_group,
model_shard_infer_config=model_shard_infer_config,
num_heads=module.num_heads,
hidden_size=module.hidden_size,
num_key_value_heads=module.num_key_value_heads,
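On the caller side, `from_native_module` now looks the config up via `kwargs.get("model_shard_infer_config", None)`, so whichever shardformer policy swaps in this module has to forward it. A hedged sketch of that call site; everything other than the `model_shard_infer_config` keyword is illustrative:

```python
# Hypothetical policy-side call; the actual ColossalAI policy code may differ.
attn_layer = NopadLlamaAttention.from_native_module(
    module.self_attn,                                   # original LlamaAttention
    process_group=process_group,
    model_shard_infer_config=model_shard_infer_config,  # consumed via kwargs.get(...)
)
```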
@@ -527,7 +524,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
)
block_size = k_cache.size(-2)
attn_metadata = AttentionMetaData(
query_states=query_states,
key_states=key_states,
@@ -544,38 +541,34 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
output_tensor=output_tensor,
use_spec_dec=is_verifier,
use_alibi_attn=False,
use_cuda_kernel=use_cuda_kernel,
)
attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
if is_prompts: # prefilling stage
pre_attention_backend.prefill(
self.pre_attention_backend.prefill(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
high_precision=high_precision,
)
attn_output = attention_backend.prefill(
attn_output = self.attention_backend.prefill(
attn_metadata,
token_nums=token_nums,
)
else: # decoding stage
)
else: # decoding stage
q_len = tokens_to_verify + 1 if is_verifier else 1
pre_attention_backend.decode(
self.pre_attention_backend.decode(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
q_len=q_len,
)
attn_output = attention_backend.decode(
attn_metadata,
fd_inter_tensor=fd_inter_tensor,
attn_output = self.attention_backend.decode(
attn_metadata,
fd_inter_tensor=fd_inter_tensor,
num_key_value_groups=self.num_key_value_groups,
q_len=q_len,
)
)
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = self.o_proj(attn_output)
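The rewritten forward pass now only touches the two backends cached on `self`, handing each a fully populated `AttentionMetaData`. A sketch of the contract it relies on, with method names taken from the calls above and everything else assumed; the real backend classes under `colossalai.inference.modeling.backends` may declare richer signatures:

```python
# Sketch of the backend interfaces the forward pass above depends on; the real
# classes under colossalai.inference.modeling.backends are authoritative.
from abc import ABC, abstractmethod

import torch


class AttentionBackendSketch(ABC):
    @abstractmethod
    def prefill(self, attn_metadata, **kwargs) -> torch.Tensor:
        """Context (prompt) attention; returns the attention output tensor."""

    @abstractmethod
    def decode(self, attn_metadata, **kwargs) -> torch.Tensor:
        """Decoding attention for one (or a few speculative) tokens."""


class PreAttentionBackendSketch(ABC):
    @abstractmethod
    def prefill(self, attn_metadata, **kwargs) -> None:
        """Apply rotary embedding and populate the KV cache before prefill."""

    @abstractmethod
    def decode(self, attn_metadata, **kwargs) -> None:
        """Apply rotary embedding and append new entries to the KV cache."""
```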
@@ -633,4 +626,3 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
def extra_repr(self) -> str:
return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"