Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 06:30:41 +00:00)
Pass inference model shard configs for module init
Signed-off-by: char-1ee <xingjianli59@gmail.com>
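Summary of the change: the per-shard inference config (ModelShardInferenceConfig) is now handed to the attention modules when they are constructed, so the attention and pre-attention backends are resolved once in __init__ instead of being looked up on every forward call, and the use_cuda_kernel flag is read from that config rather than threaded through forward. A minimal sketch of the resulting pattern, assuming only what the diffs below show (the config class, the two backend factories, and the prefill/decode calls are real names taken from the changed files; the stand-in class and its trimmed-down forward signature are illustrative, not the actual module code):

# Sketch only: simplified from the diffs in this commit, not the real module.
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend


class SketchNopadAttention:  # stand-in for NopadBaichuanAttention / NopadLlamaAttention
    def __init__(self, model_shard_infer_config: ModelShardInferenceConfig):
        # Backends are now chosen once per module from the shard config...
        self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
        self.attention_backend = get_attention_backend(model_shard_infer_config)
        self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)

    def forward(self, attn_metadata, is_prompts, cos_sin, token_nums, fd_inter_tensor, q_len, high_precision=False):
        # ...so forward simply dispatches to the pre-selected backends and no
        # longer needs a use_cuda_kernel argument.
        if is_prompts:  # prefilling stage
            self.pre_attention_backend.prefill(attn_metadata, cos=cos_sin[0], sin=cos_sin[1], high_precision=high_precision)
            return self.attention_backend.prefill(attn_metadata, token_nums=token_nums)
        # decoding stage
        self.pre_attention_backend.decode(attn_metadata, cos=cos_sin[0], sin=cos_sin[1], q_len=q_len)
        return self.attention_backend.decode(attn_metadata, fd_inter_tensor=fd_inter_tensor, q_len=q_len)

The two diffs below apply this pattern to NopadBaichuanAttention and NopadLlamaAttention.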
--- a/colossalai/inference/modeling/models/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -1,31 +1,23 @@
 # This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
 import itertools
 import math
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 from torch.distributed import ProcessGroup

+from colossalai.inference.config import ModelShardInferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
-from colossalai.inference.utils import get_alibi_slopes
-from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
 from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
 from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
+from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
+from colossalai.inference.utils import get_alibi_slopes
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
-from colossalai.kernel.triton import (
-    context_attention_unpadded,
-    copy_k_to_blocked_cache,
-    decoding_fused_rotary_embedding,
-    flash_decoding_attention,
-    rms_layernorm,
-    rotary_embedding,
-)
+from colossalai.kernel.triton import rms_layernorm
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer.layer.parallel_module import ParallelModule
 from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor

 inference_ops = InferenceOpsLoader().load()
 logger = get_dist_logger(__name__)
@@ -69,6 +61,7 @@ class NopadBaichuanAttention(ParallelModule):
         attn_oproj: ParallelModule = None,
         num_heads: int = None,
         hidden_size: int = None,
+        model_shard_infer_config: ModelShardInferenceConfig = None,
         process_group: ProcessGroup = None,
         helper_layout: Layout = None,
     ):
@@ -93,6 +86,9 @@ class NopadBaichuanAttention(ParallelModule):
         self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))

         self.helper_layout = helper_layout
+        self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
+        self.attention_backend = get_attention_backend(model_shard_infer_config)
+        self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)

         self.alibi_slopes = None
         self.use_alibi_attn = False
@@ -122,6 +118,7 @@ class NopadBaichuanAttention(ParallelModule):
         attn_kproj_w = k_proj_w
         attn_vproj_w = v_proj_w
         attn_oproj = module.o_proj
+        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)

         helper_layout = (
             module.W_pack.weight.dist_layout
@@ -133,6 +130,7 @@ class NopadBaichuanAttention(ParallelModule):
             attn_kproj_w=attn_kproj_w,
             attn_vproj_w=attn_vproj_w,
             attn_oproj=attn_oproj,
+            model_shard_infer_config=model_shard_infer_config,
             num_heads=module.num_heads,
             hidden_size=module.hidden_size,
             process_group=process_group,
@@ -201,7 +199,6 @@ class NopadBaichuanAttention(ParallelModule):
         kv_seq_len: int = 0,
         output_tensor: torch.Tensor = None,
         sm_scale: int = None,
-        use_cuda_kernel: bool = True,
         cu_seqlens: torch.Tensor = None,
         high_precision: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -220,7 +217,6 @@ class NopadBaichuanAttention(ParallelModule):
             kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
             output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
             sm_scale (int, optional): Used for flash attention. Defaults to None.
-            use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
             cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
             high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
         """
@@ -233,7 +229,7 @@ class NopadBaichuanAttention(ParallelModule):
         )

         block_size = k_cache.size(-2)

         attn_metadata = AttentionMetaData(
             query_states=query_states,
             key_states=key_states,
@@ -250,35 +246,31 @@ class NopadBaichuanAttention(ParallelModule):
             output_tensor=output_tensor,
             use_spec_dec=is_verifier,
             use_alibi_attn=self.use_alibi_attn,
-            use_cuda_kernel=use_cuda_kernel,
         )

-        attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
-        pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
-
         if is_prompts:  # prefilling stage
-            pre_attention_backend.prefill(
+            self.pre_attention_backend.prefill(
                 attn_metadata,
                 cos=cos_sin[0],
                 sin=cos_sin[1],
                 high_precision=high_precision,
             )
-            attn_output = attention_backend.prefill(
+            attn_output = self.attention_backend.prefill(
                 attn_metadata,
                 token_nums=token_nums,
             )
         else:  # decoding stage
             q_len = tokens_to_verify + 1 if is_verifier else 1

-            pre_attention_backend.decode(
+            self.pre_attention_backend.decode(
                 attn_metadata,
                 cos=cos_sin[0],
                 sin=cos_sin[1],
                 q_len=q_len,
             )
-            attn_output = attention_backend.decode(
+            attn_output = self.attention_backend.decode(
                 attn_metadata,
                 fd_inter_tensor=fd_inter_tensor,
                 q_len=q_len,
             )
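The llama attention module below receives the same treatment. The caller-side wiring is visible in the from_native_module hunks of both files: the shardformer policy passes the shard config through **kwargs, and the module forwards it to its constructor. A hedged, self-contained sketch of just that hand-off (the kwargs.get call and the model_shard_infer_config keyword come from the hunks; the stand-in class is illustrative, not the real one):

# Abbreviated sketch of the config hand-off in from_native_module; not the real class.
class SketchParallelAttention:
    def __init__(self, model_shard_infer_config=None, process_group=None):
        self.model_shard_infer_config = model_shard_infer_config
        self.process_group = process_group

    @classmethod
    def from_native_module(cls, module, process_group=None, *args, **kwargs):
        # The policy that shards the model is expected to put the config into kwargs.
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
        return cls(
            model_shard_infer_config=model_shard_infer_config,
            process_group=process_group,
        )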
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -16,21 +16,13 @@ from transformers.models.llama.modeling_llama import (
     LlamaRMSNorm,
 )

-from colossalai.inference.config import InputMetaData
+from colossalai.inference.config import InputMetaData, ModelShardInferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
-from colossalai.inference.modeling.backends.attention_backend import get_attention_backend, AttentionMetaData
+from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
 from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
 from colossalai.inference.utils import can_use_flash_attn2
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
-from colossalai.kernel.triton import (
-    context_attention_unpadded,
-    copy_k_to_blocked_cache,
-    decoding_fused_rotary_embedding,
-    flash_decoding_attention,
-    get_xine_cache,
-    rms_layernorm,
-    rotary_embedding,
-)
+from colossalai.kernel.triton import get_xine_cache, rms_layernorm
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer.layer.parallel_module import ParallelModule
 from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor
@@ -233,7 +225,6 @@ def llama_decoder_layer_forward(
         kv_seq_len=kv_seq_len,
         output_tensor=output_tensor,
         sm_scale=sm_scale,
-        use_cuda_kernel=use_cuda_kernel,
         cu_seqlens=cu_seqlens,
         high_precision=high_precision,
     )
@@ -397,6 +388,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
         attn_vproj_w: torch.Tensor = None,
         attn_oproj: ParallelModule = None,
         process_group: ProcessGroup = None,
+        model_shard_infer_config: ModelShardInferenceConfig = None,
         num_heads: int = None,
         hidden_size: int = None,
         num_key_value_heads: int = None,
@@ -428,6 +420,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
         self.rope_theta = config.rope_theta
         self.is_causal = True

+        self.attention_backend = get_attention_backend(model_shard_infer_config)
+        self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
+
         if self.num_heads == self.num_key_value_heads:
             qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
             self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
@@ -457,6 +452,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
         attn_vproj_w = module.v_proj.weight
         assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor"
         attn_oproj = module.o_proj
+        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)

         attn_layer = NopadLlamaAttention(
             config=config,
@@ -466,6 +462,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
             attn_vproj_w=attn_vproj_w,
             attn_oproj=attn_oproj,
             process_group=process_group,
+            model_shard_infer_config=model_shard_infer_config,
             num_heads=module.num_heads,
             hidden_size=module.hidden_size,
             num_key_value_heads=module.num_key_value_heads,
@@ -527,7 +524,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
         )

         block_size = k_cache.size(-2)

         attn_metadata = AttentionMetaData(
             query_states=query_states,
             key_states=key_states,
@@ -544,38 +541,34 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
             output_tensor=output_tensor,
             use_spec_dec=is_verifier,
             use_alibi_attn=False,
-            use_cuda_kernel=use_cuda_kernel,
         )

-        attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
-        pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
-
         if is_prompts:  # prefilling stage
-            pre_attention_backend.prefill(
+            self.pre_attention_backend.prefill(
                 attn_metadata,
                 cos=cos_sin[0],
                 sin=cos_sin[1],
                 high_precision=high_precision,
             )
-            attn_output = attention_backend.prefill(
+            attn_output = self.attention_backend.prefill(
                 attn_metadata,
                 token_nums=token_nums,
             )
         else:  # decoding stage
             q_len = tokens_to_verify + 1 if is_verifier else 1

-            pre_attention_backend.decode(
+            self.pre_attention_backend.decode(
                 attn_metadata,
                 cos=cos_sin[0],
                 sin=cos_sin[1],
                 q_len=q_len,
             )
-            attn_output = attention_backend.decode(
+            attn_output = self.attention_backend.decode(
                 attn_metadata,
                 fd_inter_tensor=fd_inter_tensor,
                 num_key_value_groups=self.num_key_value_groups,
                 q_len=q_len,
             )

         attn_output = attn_output.view(-1, self.hidden_size)
         attn_output = self.o_proj(attn_output)
@@ -633,4 +626,3 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):

     def extra_repr(self) -> str:
         return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"
-