[inference] add llama2 support (#4898)

* add llama2 support

* fix multi-query group bug
Xu Kai 2023-10-13 13:09:23 +08:00 committed by GitHub
parent 39f2582e98
commit 77a9328304
5 changed files with 152 additions and 43 deletions

View File

@@ -1,7 +1,6 @@
 from typing import Any, Callable, List, Optional, Union
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from transformers import BloomForCausalLM, LlamaForCausalLM
 from transformers.generation import GenerationConfig
@@ -74,9 +73,14 @@ class TPInferEngine:
             model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
         )
         self.layer_num = num_hidden_layers
-        self.multi_query_group_num = (
-            model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0
-        )
+        self.multi_query_group_num = 0
+        if hasattr(model.config, "multi_query_group_num"):
+            self.multi_query_group_num = model.config.multi_query_group_num
+        if hasattr(model.config, "num_key_value_heads"):
+            self.multi_query_group_num = model.config.num_key_value_heads
         self.tp_size = -1  # to be set with given shard config in self.prepare_shard_config
         self.cache_manager = None
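Note: the engine now derives the number of key/value heads from either a ChatGLM2-style `multi_query_group_num` or a llama2-style `num_key_value_heads` config field. A minimal sketch of the same detection and the resulting group arithmetic, using a hypothetical GQA config in the spirit of Llama-2-70B (64 query heads, 8 KV heads; all values here are illustrative, and a transformers version with GQA support is assumed):

from transformers import LlamaConfig

# Hypothetical grouped-query-attention config (illustrative values).
cfg = LlamaConfig(num_attention_heads=64, num_key_value_heads=8)

# Same detection order as in the diff above: llama2-style configs expose
# `num_key_value_heads`, ChatGLM2-style configs expose `multi_query_group_num`.
multi_query_group_num = 0
if hasattr(cfg, "multi_query_group_num"):
    multi_query_group_num = cfg.multi_query_group_num
if hasattr(cfg, "num_key_value_heads"):
    multi_query_group_num = cfg.num_key_value_heads

# Each key/value head is shared by this many query heads; 1 means plain multi-head attention.
num_key_value_groups = cfg.num_attention_heads // multi_query_group_num  # 64 // 8 == 8
print(multi_query_group_num, num_key_value_groups)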
@@ -97,6 +101,7 @@ class TPInferEngine:
         assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
         assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
         self.head_num //= self.tp_size  # update sharded number of heads
         if self.multi_query_group_num:
             # NOTE the logic of MQA tensor parallelism should be specified.
             assert (
@@ -116,13 +121,15 @@ class TPInferEngine:
     def _post_init_gptq_buffer(self, model: nn.Module) -> None:
         from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear
         HAS_GPTQ_CUDA = False
         try:
             from colossalai.kernel.op_builder.gptq import GPTQBuilder
             gptq_cuda = GPTQBuilder().load()
             HAS_GPTQ_CUDA = True
         except ImportError:
-            warnings.warn('CUDA gptq is not installed')
+            warnings.warn("CUDA gptq is not installed")
             HAS_GPTQ_CUDA = False
         for name, submodule in model.named_modules():
@@ -130,8 +137,9 @@ class TPInferEngine:
                 self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)
                 if self.use_act_order:
-                    self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures,
-                                                   submodule.outfeatures)
+                    self.max_inner_outer_dim = max(
+                        self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
+                    )
                 self.bits = submodule.bits
         if not (HAS_GPTQ_CUDA and self.bits == 4):
             return
@@ -141,15 +149,16 @@ class TPInferEngine:
         max_input_len = self.max_input_len
         # The temp_state buffer is required to reorder X in the act-order case.
         # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
-        self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim),
-                                                  dtype=torch.float16,
-                                                  device=torch.cuda.current_device())
-        self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size),
-                                               dtype=torch.float16,
-                                               device=torch.cuda.current_device())
+        self.gptq_temp_state_buffer = torch.zeros(
+            (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
+        )
+        self.gptq_temp_dq_buffer = torch.zeros(
+            (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
+        )
-        gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer,
-                                  self.gptq_temp_dq_buffer)
+        gptq_cuda.prepare_buffers(
+            torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
+        )
         # Using the default from exllama repo here.
         matmul_recons_thd = 8
         matmul_fused_remap = False
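For orientation, a rough back-of-the-envelope sizing of these exllama-style scratch buffers, assuming a hypothetical 4-bit layer with 4096x4096 weights and max_input_len = 2048 (all numbers below are illustrative, not taken from this commit):

in_features, out_features, max_input_len = 4096, 4096, 2048

# GPTQ packs eight 4-bit weights into each int32, so qweight holds in_features // 8 rows.
qweight_numel = (in_features // 8) * out_features           # 2_097_152 int32 values
max_dq_buffer_size = qweight_numel * 8                       # fp16 elements after dequantization
max_inner_outer_dim = max(in_features, out_features)

temp_dq_bytes = max_dq_buffer_size * 2                       # ~32 MiB of fp16 scratch
temp_state_bytes = max_input_len * max_inner_outer_dim * 2   # ~16 MiB for act-order reordering
print(temp_dq_bytes // 2**20, temp_state_bytes // 2**20)     # 32 16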

View File

@@ -45,7 +45,7 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
     base = float(base)
     # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
-    ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None))
+    ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
     if ntk_alpha is not None:
         ntk_alpha = float(ntk_alpha)
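The fix matters because `float(None)` raises a TypeError whenever the `INFER_NTK_ALPHA` environment variable is unset; the value is now converted only after it is known to exist. As a reminder of what the alpha controls, here is a minimal sketch of the commonly used NTK-aware base change referenced in the comment (my own illustration with a made-up helper name, not the code in this file):

import torch

def ntk_scaled_inv_freq(head_dim: int, base: float = 10000.0, ntk_alpha: float = 1.0) -> torch.Tensor:
    # NTK-aware scaling raises the RoPE base so low-frequency components are stretched
    # for longer contexts while high-frequency components stay nearly unchanged.
    base = base * ntk_alpha ** (head_dim / (head_dim - 2))
    return 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))

# ntk_alpha = 1.0 reproduces vanilla RoPE frequencies.
inv_freq = ntk_scaled_inv_freq(head_dim=128, ntk_alpha=2.0)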

View File

@@ -5,7 +5,13 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd
+from colossalai.kernel.triton import (
+    llama2_context_attn_fwd,
+    llama_context_attn_fwd,
+    rotary_embedding_fwd,
+    token_attention_fwd,
+)
+from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
 from ._utils import copy_kv_to_mem_cache
@@ -138,6 +144,7 @@ class LlamaInferenceForwards:
             seq_len = infer_state.seq_len
             infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
             infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+            infer_state.other_kv_index = infer_state.block_loc[0, seq_length_with_past - 1].item()
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -261,8 +268,8 @@ class LlamaInferenceForwards:
         # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
         # NOTE might want to revise
         # need some way to record the length of past key values cache
@@ -274,11 +281,11 @@ class LlamaInferenceForwards:
         # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
         rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
-        rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+        rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
         query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
-        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
-        value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
+        key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim)
         if infer_state.is_context_stage:
             # first token generation
@@ -294,6 +301,7 @@ class LlamaInferenceForwards:
             attn_output = torch.empty_like(query_states)
+            if self.num_key_value_groups == 1:
                 llama_context_attn_fwd(
                     query_states,
                     key_states,
@@ -303,6 +311,16 @@ class LlamaInferenceForwards:
                     infer_state.seq_len,
                     infer_state.cache_manager.past_key_values_length,
                 )
+            else:
+                llama2_context_attn_fwd(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_output,
+                    infer_state.start_loc,
+                    infer_state.seq_len,
+                    infer_state.cache_manager.past_key_values_length,
+                )
         else:
             if infer_state.decode_is_contiguous:
                 # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
@@ -330,6 +348,7 @@ class LlamaInferenceForwards:
             # (batch_size, seqlen, nheads, headdim)
             attn_output = torch.empty_like(query_states)
+            if self.num_key_value_groups == 1:
                 token_attention_fwd(
                     query_states,
                     infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
@@ -340,7 +359,18 @@ class LlamaInferenceForwards:
                     infer_state.seq_len,
                     infer_state.cache_manager.past_key_values_length,
                 )
+            else:
+                Llama2TokenAttentionForwards.token_attn(
+                    query_states,
+                    infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+                    infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+                    attn_output,
+                    infer_state.block_loc,
+                    infer_state.start_loc,
+                    infer_state.seq_len,
+                    infer_state.cache_manager.past_key_values_length,
+                    infer_state.other_kv_index,
+                )
         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
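In both branches above, the new llama2 kernels (`llama2_context_attn_fwd` and `Llama2TokenAttentionForwards.token_attn`) handle the grouped-query case where several query heads share one key/value head. As a plain-PyTorch reference for what that computation amounts to (my own sketch with made-up helper names, assuming PyTorch >= 2.0 for scaled_dot_product_attention; it is not the Triton kernel itself):

import torch
import torch.nn.functional as F

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # [bsz, seq_len, num_key_value_heads, head_dim] -> [bsz, seq_len, num_key_value_heads * n_rep, head_dim]
    bsz, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bsz, seq_len, n_kv_heads, n_rep, head_dim)
        .reshape(bsz, seq_len, n_kv_heads * n_rep, head_dim)
    )

def gqa_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: [bsz, seq_len, num_heads, head_dim]; k, v: [bsz, seq_len, num_key_value_heads, head_dim]
    n_rep = q.shape[2] // k.shape[2]                  # num_key_value_groups
    k, v = repeat_kv(k, n_rep), repeat_kv(v, n_rep)   # broadcast KV heads across the query heads
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> [bsz, heads, seq_len, head_dim]
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(1, 2)

q = torch.randn(1, 16, 64, 128)
k = torch.randn(1, 16, 8, 128)
v = torch.randn(1, 16, 8, 128)
print(gqa_reference(q, k, v).shape)  # torch.Size([1, 16, 64, 128])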

View File

@@ -9,7 +9,7 @@ except ImportError:
     # There may exist import error even if we have triton installed.
 if HAS_TRITON:
-    from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
+    from .context_attention import bloom_context_attn_fwd, llama2_context_attn_fwd, llama_context_attn_fwd
     from .copy_kv_cache_dest import copy_kv_cache_to_dest
     from .fused_layernorm import layer_norm
     from .gptq_triton import gptq_fused_linear_triton
@@ -20,6 +20,7 @@ if HAS_TRITON:
     __all__ = [
         "llama_context_attn_fwd",
+        "llama2_context_attn_fwd",
         "bloom_context_attn_fwd",
         "softmax",
         "layer_norm",

View File

@@ -0,0 +1,69 @@
+import os
+
+import pytest
+import torch
+from packaging import version
+from transformers import LlamaForCausalLM
+from transformers.models.llama.configuration_llama import LlamaConfig
+
+import colossalai
+from colossalai.inference.tensor_parallel.engine import TPInferEngine
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import ShardConfig
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
+
+TPSIZE = 2
+BATCH_SIZE = 8
+MAX_INPUT_LEN = 12
+MAX_OUTPUT_LEN = 100
+
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
+
+
+@parameterize(
+    "test_config",
+    [
+        {
+            "tp_size": TPSIZE,
+        }
+    ],
+)
+def run_llama_test(test_config):
+    llama_config = LlamaConfig(
+        num_hidden_layers=2, num_key_value_heads=8, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024
+    )
+    model = LlamaForCausalLM(llama_config)
+    model = model.half()
+
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
+    )
+    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
+
+    input_tokens = {
+        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+    }
+    outputs = infer_engine.generate(input_tokens, **generate_kwargs)
+
+    assert outputs is not None
+
+
+def check_llama(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_llama_test()
+
+
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_llama():
+    spawn(check_llama, TPSIZE)
+
+
+if __name__ == "__main__":
+    test_llama()