[Inference]Add Nopadding Llama Modeling (#5327)

* add nopadding llama modeling

* add nopadding_llama.py

* rm unused codes

* fix bugs in test_xine_copy.py

* fix code style
yuehuayingxueluo 2024-01-30 10:31:46 +08:00 committed by GitHub
parent c7c104cb7c
commit e8f0642f28
9 changed files with 386 additions and 49 deletions

View File

@@ -32,6 +32,7 @@ class InferenceConfig:
             During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
         prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, we will do a step of prefill
             when the actual value exceeds this ratio.
+        pad_input: Whether to pad all inputs to the max length.
         quant_mode (Optional[str]): Quantization mode.
         revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use.
     """
@@ -49,6 +50,7 @@ class InferenceConfig:
     beam_width: int = 1
     # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
     prefill_ratio: Optional[float] = 1.2
+    pad_input: bool = False
     quant_mode: Optional[str] = None
     revision: Optional[str] = None
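The new pad_input flag defaults to False, so the no-padding path is taken unless a caller explicitly opts back into padded inputs. A minimal sketch of toggling the flag (the import path below is an assumption for illustration; only the pad_input field itself comes from this diff):

    # Hypothetical import path -- adjust to wherever InferenceConfig lives in your tree.
    from colossalai.inference.config import InferenceConfig

    nopad_cfg = InferenceConfig()               # default: pad_input=False, no-padding modeling
    pad_cfg = InferenceConfig(pad_input=True)   # opt back into the padded code path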

View File

@@ -57,7 +57,11 @@ class InferenceEngine:
             model.to(self.dtype)

         if model_policy is None:
-            model_policy = model_policy_map[self.model_config.model_type]()
+            if self.inference_config.pad_input:
+                model_type = "padding_" + self.model_config.model_type
+            else:
+                model_type = "nopadding_" + self.model_config.model_type
+            model_policy = model_policy_map[model_type]()

         pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
@@ -168,7 +172,9 @@ class InferenceEngine:
         if prompts_token_ids is None:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
+            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
+                "input_ids"
+            ]

         if isinstance(prompts_token_ids, list):
             pass
@@ -237,7 +243,9 @@ class InferenceEngine:
             self.v_cache,
         )

-        logits = logits[:, -1, :]
+        if self.inference_config.pad_input:
+            logits = logits[:, -1, :]

         self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()
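The engine now slices off the last position only when inputs are padded, because the two model variants hand back logits in different shapes: the padded model returns one row per (padded) position, while the no-padding model already gathers the last token of each sequence before the LM head. A rough shape illustration with made-up sizes (not code from this commit):

    import torch

    batch_size, seq_len, vocab_size = 2, 8, 32000   # illustrative sizes

    # Padded path: logits cover every padded position; keep only the last one.
    padded_logits = torch.randn(batch_size, seq_len, vocab_size)
    last_logits = padded_logits[:, -1, :]            # [batch_size, vocab_size]

    # No-padding path: the model already returns one row per sequence.
    nopad_logits = torch.randn(batch_size, vocab_size)

    assert last_logits.shape == nopad_logits.shape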

View File

@@ -0,0 +1,221 @@
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
from typing import List, Optional, Tuple

import torch
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
)

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.struct import BatchInfo
from colossalai.logging import get_dist_logger
from flash_attn.bert_padding import index_first_axis, pad_input  # noqa

logger = get_dist_logger(__name__)

try:
    from colossalai.kernel.triton import (
        context_attention_unpadded,
        copy_kv_to_blocked_cache,
        flash_decoding_attention,
        get_xine_cache,
        rotary_embedding,
    )

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    logger.warning("triton has not been installed yet, we will use torch to complete the attention calculation.")


@torch.no_grad()
def llama_causal_lm_forward(
    self: LlamaForCausalLM,
    batch: BatchInfo = None,
    k_caches: List[torch.Tensor] = None,
    v_caches: List[torch.Tensor] = None,
):
    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    hidden_states = llama_model_forward(
        self.model,
        batch=batch,
        k_caches=k_caches,
        v_caches=v_caches,
    )
    logits = torch.mm(hidden_states, self.lm_head.weight.transpose(0, 1))
    return logits


@torch.no_grad()
def llama_model_forward(
    self: LlamaModel,
    batch: BatchInfo = None,
    k_caches: List[torch.Tensor] = None,
    v_caches: List[torch.Tensor] = None,
):
    input_ids = batch.get_1D_inputs()
    block_tables = batch.get_block_table_tensor()
    sequence_lengths = batch.get_sequence_lengths()
    batch_size = len(sequence_lengths)
    kv_seq_len = sequence_lengths.max().item()

    hidden_states = self.embed_tokens(input_ids)

    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)

    if batch.is_prompts:
        output_tensor = torch.zeros(
            (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
        )
    else:
        output_tensor = torch.zeros(
            (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
        )
    sm_scale = 1.0 / (batch.head_dim**0.5)

    for layer_id, decoder_layer in enumerate(self.layers):
        hidden_states = decoder_layer(
            hidden_states,
            block_tables=block_tables,
            k_cache=k_caches[layer_id],
            v_cache=v_caches[layer_id],
            is_prompts=batch.is_prompts,
            sequence_lengths=sequence_lengths,
            kv_seq_len=kv_seq_len,
            cos_sin=cos_sin,
            fd_inter_tensor=batch.fd_inter_tensor,
            output_tensor=output_tensor,
            sm_scale=sm_scale,
        )

    if batch.is_prompts:
        # Prefill: keep only the last hidden state of each packed sequence for the LM head.
        last_token_indexs = sequence_lengths.cumsum(dim=-1)
        hidden_states = hidden_states[last_token_indexs - 1].contiguous()
    hidden_states = self.norm(hidden_states)

    return hidden_states


@torch.no_grad()
def llama_decoder_layer_forward(
    self: LlamaDecoderLayer,
    hidden_states: torch.Tensor,
    block_tables: torch.Tensor = None,
    k_cache: torch.Tensor = None,
    v_cache: torch.Tensor = None,
    is_prompts: bool = True,
    sequence_lengths: torch.Tensor = None,
    kv_seq_len: int = 0,
    cos_sin: Tuple[torch.Tensor] = None,
    fd_inter_tensor: FDIntermTensors = None,
    output_tensor: torch.Tensor = None,
    sm_scale: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)
    # Self Attention
    hidden_states = self.self_attn(
        hidden_states=hidden_states,
        block_tables=block_tables,
        k_cache=k_cache,
        v_cache=v_cache,
        is_prompts=is_prompts,
        sequence_lengths=sequence_lengths,
        kv_seq_len=kv_seq_len,
        cos_sin=cos_sin,
        fd_inter_tensor=fd_inter_tensor,
        output_tensor=output_tensor,
        sm_scale=sm_scale,
    )

    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    return hidden_states


# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
@torch.no_grad()
def llama_attn_forward(
    self: LlamaAttention,
    hidden_states: torch.Tensor,
    block_tables: torch.Tensor = None,
    k_cache: torch.Tensor = None,
    v_cache: torch.Tensor = None,
    is_prompts: bool = True,
    sequence_lengths: torch.Tensor = None,
    kv_seq_len: int = 0,
    cos_sin: Tuple[torch.Tensor] = None,
    fd_inter_tensor: FDIntermTensors = None,
    output_tensor: torch.Tensor = None,
    sm_scale: int = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    query_states = torch.mm(hidden_states, self.q_proj.weight.transpose(0, 1)).view(-1, self.num_heads, self.head_dim)
    key_states = torch.mm(hidden_states, self.k_proj.weight.transpose(0, 1)).view(
        -1, self.num_key_value_heads, self.head_dim
    )
    value_states = torch.mm(hidden_states, self.v_proj.weight.transpose(0, 1)).view(
        -1, self.num_key_value_heads, self.head_dim
    )

    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])

    _, _, _, block_size = k_cache.shape

    if is_prompts:
        # Prefill: context attention over the unpadded (packed) sequences.
        attn_output = context_attention_unpadded(
            q=query_states,
            k=key_states,
            v=value_states,
            k_cache=k_cache,
            v_cache=v_cache,
            context_lengths=sequence_lengths,
            block_tables=block_tables,
            block_size=block_size,
            output=output_tensor,
            max_seq_len=kv_seq_len,
            sm_scale=sm_scale,
        )
    else:
        # Decoding: write the new K/V into the blocked cache, then run flash-decoding attention.
        copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
        copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
        attn_output = flash_decoding_attention(
            q=query_states,
            k_cache=k_cache,
            v_cache=v_cache,
            kv_seq_len=sequence_lengths,
            block_tables=block_tables,
            block_size=block_size,
            max_seq_len_in_batch=kv_seq_len,
            output=output_tensor,
            mid_output=fd_inter_tensor.mid_output,
            mid_output_lse=fd_inter_tensor.mid_output_lse,
            sm_scale=sm_scale,
        )
        attn_output = attn_output.squeeze(1)

    attn_output = attn_output.view(-1, self.num_heads, self.head_dim)
    attn_output = attn_output.reshape(-1, self.hidden_size)
    attn_output = torch.mm(attn_output, self.o_proj.weight.transpose(0, 1))

    return attn_output


@torch.no_grad()
def nopad_mlp(self: LlamaMLP, hidden_states: torch.Tensor):
    gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight.transpose(0, 1))
    act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
    up_proj_out = torch.mm(hidden_states, self.up_proj.weight.transpose(0, 1))
    tmp_out = act_out * up_proj_out
    return torch.mm(tmp_out, self.down_proj.weight.transpose(0, 1))
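The forward pass above keeps all sequences packed back-to-back along a single token dimension instead of padding them to a common length; during prefill, the last hidden state of each sequence therefore sits at its cumulative-length offset. A self-contained toy version of that gather (toy sizes, not code from the commit):

    import torch

    sequence_lengths = torch.tensor([3, 2])                        # two packed sequences
    hidden_states = torch.randn(int(sequence_lengths.sum()), 4)    # [total_tokens, hidden_size]

    # Same gather as llama_model_forward: the last token of each sequence
    # lives at (cumulative length - 1) in the packed layout.
    last_token_indexs = sequence_lengths.cumsum(dim=-1)            # tensor([3, 5])
    last_hidden = hidden_states[last_token_indexs - 1]             # rows 2 and 4 -> shape [2, 4]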

View File

@@ -11,6 +11,7 @@ from colossalai.kernel.triton import (
     context_attention_unpadded,
     copy_kv_to_blocked_cache,
     flash_decoding_attention,
+    get_xine_cache,
     rotary_embedding,
 )
 from colossalai.logging import get_dist_logger
@@ -101,12 +102,7 @@ def llama_model_forward(
     hidden_states = self.embed_tokens(input_ids)

-    # When testing, the performance of get_xine_cache is lower than that of get_cos_sin.
-    # cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
-    # sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
-    # cos_sin = (cos, sin)
-    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)
+    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)

     if batch.is_prompts:
         output_tensor = torch.zeros(
@@ -135,7 +131,9 @@ def llama_model_forward(
             sm_scale=sm_scale,
         )

+    hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
     hidden_states = self.norm(hidden_states)

     return hidden_states
@@ -327,26 +325,3 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_
     k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
     v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
     return (q, k, v, indices)
-
-
-@torch.no_grad()
-def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
-    """
-    Get cos and sin for the cache, and return nopad format.
-    Args:
-        lengths: shape(num_seqs,), stores lenghth of each sequence.
-        cos_cache: shape(max_rotary_position(e.g.2048), head_dim), cos cache constrcuted in model.
-        sin_cache: shape(max_rotary_position(e.g.2048), head_dim), sin cache constrcuted in model.
-        is_prompts: bool, mark if in prefill mode.
-        dtype: The data type of this inference process.
-    """
-    if is_prompts:
-        index_arrays = [torch.arange(length) for length in lengths]
-    else:
-        index_arrays = [(length - 1).view(-1) for length in lengths]
-    indices = torch.cat(index_arrays, dim=-1)
-    cos_output = cos_cache[indices].to(dtype=dtype)
-    sin_output = sin_cache[indices].to(dtype=dtype)
-    return (cos_output, sin_output)
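The removed get_cos_sin helper (now kept only in the xine-cache test) builds a flat position index: every position of every sequence during prefill, and only the latest position during decoding; the Triton get_xine_cache kernel takes over that job here. A small worked example of the indices it produces (toy lengths, not part of the commit):

    import torch

    lengths = torch.tensor([3, 2])

    # Prefill: all positions of all sequences, concatenated.
    prefill_idx = torch.cat([torch.arange(l) for l in lengths])     # tensor([0, 1, 2, 0, 1])

    # Decoding: only the latest position of each sequence.
    decode_idx = torch.cat([(l - 1).view(-1) for l in lengths])     # tensor([2, 1])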

View File

@@ -1,7 +1,9 @@
-from .llama import LlamaModelInferPolicy
+from .nopadding_llama import NoPaddingLlamaModelInferPolicy
+from .padding_llama import PaddingLlamaModelInferPolicy

 model_policy_map = {
-    "llama": LlamaModelInferPolicy,
+    "padding_llama": PaddingLlamaModelInferPolicy,
+    "nopadding_llama": NoPaddingLlamaModelInferPolicy,
 }

-__all__ = ["LlamaModelInferPolicy", "model_polic_map"]
+__all__ = ["PaddingLlamaModelInferPolicy", "NoPaddingLlamaModelInferPolicy", "model_polic_map"]

View File

@@ -0,0 +1,107 @@
from functools import partial

import torch
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaFlashAttention2,
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
    LlamaRMSNorm,
    LlamaSdpaAttention,
)

from colossalai.inference.modeling.models.nopadding_llama import (
    llama_attn_forward,
    llama_causal_lm_forward,
    llama_decoder_layer_forward,
    llama_model_forward,
    nopad_mlp,
)
from colossalai.inference.utils import init_to_get_rotary

# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

try:
    from colossalai.kernel.triton import rms_layernorm

    HAS_TRITON_RMSNORM = True
except:
    print("you should install triton from https://github.com/openai/triton")
    HAS_TRITON_RMSNORM = False


def get_triton_rmsnorm_forward():
    if HAS_TRITON_RMSNORM:

        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon)

        return _triton_rmsnorm_forward
    else:
        return None


class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        self.shard_config._infer()

        infer_forward = llama_causal_lm_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaForCausalLM
        )

        infer_forward = llama_model_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)

        infer_forward = llama_decoder_layer_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
        )

        infer_forward = nopad_mlp
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaMLP)

        infer_forward = llama_attn_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaAttention
        )

        infer_forward = llama_attn_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
        )

        infer_forward = llama_attn_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
        )

        infer_forward = None
        if HAS_TRITON_RMSNORM:
            infer_forward = get_triton_rmsnorm_forward()

        if infer_forward is not None:
            method_replacement = {"forward": partial(infer_forward)}
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=LlamaRMSNorm
            )

        return policy

    def postprocess(self):
        init_to_get_rotary(self.model.model)
        return self.model
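When Triton is available, the policy swaps LlamaRMSNorm.forward for the rms_layernorm kernel. For reference, a plain-PyTorch sketch of what that replacement is expected to compute, following the standard LlamaRMSNorm formula (our sketch, not code from this commit or from the kernel itself):

    import torch

    def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Root-mean-square norm over the hidden dimension, computed in fp32 for stability.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
        return weight * normed.to(hidden_states.dtype)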

View File

@@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaSdpaAttention,
 )

-from colossalai.inference.modeling.models.llama import (
+from colossalai.inference.modeling.models.padding_llama import (
     llama_attn_forward,
     llama_causal_lm_forward,
     llama_decoder_layer_forward,
@@ -43,7 +43,7 @@ def get_triton_rmsnorm_forward():
         return None


-class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
+class PaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
     def __init__(self) -> None:
         super().__init__()

View File

@@ -358,21 +358,16 @@ class BatchInfo:
         Flattening the input tokens.
         """
         input_list = []
-        input_len_list = []

         assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."

         for seq in self.sequences_set:
             if self.is_prompts:
                 input_list.extend(seq.input_token_id)
-                input_len_list.append(seq.sentence_len)
             else:
                 input_list.append(seq.output_token_id[-1])
-                input_len_list.append(1)

-        return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor(
-            input_len_list, dtype=torch.int, device=self.device
-        )
+        return torch.tensor(input_list, dtype=torch.long, device=self.device)

     def get_sequence_lengths(self):
         """
@@ -401,7 +396,9 @@ class BatchInfo:
             past_values.append(seq.input_token_id + seq.output_token_id)

         max_seq_len = max(len(sub_list) for sub_list in past_values)
-        attn_mask = _make_tensor_with_pad(past_values, max_seq_len, 0, dtype=torch.int, device=self.device)
+        attn_mask = _make_tensor_with_pad(
+            past_values, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int, device=self.device
+        )

         return attn_mask.ne(padding_id).long()
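With the length list gone, get_1D_inputs returns just one flat token tensor; per-sequence lengths are fetched separately via get_sequence_lengths where they are needed. A toy illustration of the new return values (made-up token ids, not code from the commit):

    import torch

    prompt_a = [5, 6, 7]     # made-up token ids
    prompt_b = [8, 9]

    flat_tokens = torch.tensor(prompt_a + prompt_b, dtype=torch.long)    # tensor([5, 6, 7, 8, 9])
    sequence_lengths = torch.tensor([len(prompt_a), len(prompt_b)])      # tensor([3, 2])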

View File

@@ -2,7 +2,6 @@ import pytest
 import torch
 from packaging import version

-from colossalai.inference.modeling.models.llama import get_cos_sin
 from colossalai.kernel.triton import get_xine_cache

 try:
@@ -16,6 +15,29 @@ except ImportError:

 TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")

+
+@torch.no_grad()
+def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
+    """
+    Get cos and sin for the cache, and return nopad format.
+    Args:
+        lengths: shape(num_seqs,), stores lenghth of each sequence.
+        cos_cache: shape(max_rotary_position(e.g.2048), head_dim), cos cache constrcuted in model.
+        sin_cache: shape(max_rotary_position(e.g.2048), head_dim), sin cache constrcuted in model.
+        is_prompts: bool, mark if in prefill mode.
+        dtype: The data type of this inference process.
+    """
+    if is_prompts:
+        index_arrays = [torch.arange(length) for length in lengths]
+    else:
+        index_arrays = [(length - 1).view(-1) for length in lengths]
+    indices = torch.cat(index_arrays, dim=-1)
+    cos_output = cos_cache[indices].to(dtype=dtype)
+    sin_output = sin_cache[indices].to(dtype=dtype)
+    return (cos_output, sin_output)
+
+
 @pytest.mark.parametrize("BATCH_SIZE", [4])
 @pytest.mark.parametrize("MAX_SEQ_LEN", [64])
 @pytest.mark.parametrize("HEAD_DIM", [64])
@@ -23,15 +45,18 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
 def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
     MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN
     cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
+    sin_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
     lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda")
     # prefill
-    cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, cos_cache, is_prompts=True, dtype=dtype)
-    cos = get_xine_cache(lengths, cos_cache, is_prompts=True)
+    cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
+    cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
     assert torch.allclose(cos, cos_ref)
+    assert torch.allclose(sin, sin_ref)
     # decoding
-    ncos_ref, sin_ref = get_cos_sin(lengths, cos_cache, cos_cache, is_prompts=False, dtype=dtype)
-    cos = get_xine_cache(lengths, cos_cache, is_prompts=False)
+    ncos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
+    cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=False)
     assert torch.allclose(cos, ncos_ref)
+    assert torch.allclose(sin, sin_ref)

 configs = [