hpcaitech/ColossalAI (https://github.com/hpcaitech/ColossalAI.git)
commit 77a9328304 (parent 39f2582e98)

[inference] add llama2 support (#4898)

* add llama2 support
* fix multi group bug

@@ -1,7 +1,6 @@
 from typing import Any, Callable, List, Optional, Union

 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from transformers import BloomForCausalLM, LlamaForCausalLM
 from transformers.generation import GenerationConfig
@@ -74,9 +73,14 @@ class TPInferEngine:
             model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
         )
         self.layer_num = num_hidden_layers
-        self.multi_query_group_num = (
-            model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0
-        )
+        self.multi_query_group_num = 0
+
+        if hasattr(model.config, "multi_query_group_num"):
+            self.multi_query_group_num = model.config.multi_query_group_num
+
+        if hasattr(model.config, "num_key_value_heads"):
+            self.multi_query_group_num = model.config.num_key_value_heads
+
         self.tp_size = -1  # to be set with given shard config in self.prepare_shard_config
         self.cache_manager = None
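
Note: the new branching treats Llama-2's `num_key_value_heads` and ChatGLM-style `multi_query_group_num` as the same quantity, the number of KV head groups, defaulting to 0 for plain multi-head attention. A minimal sketch of that resolution logic; the helper name and the example config values are hypothetical, not part of the patch:

# Hypothetical helper: resolve the number of KV head groups from a HF config.
# Llama-2 exposes `num_key_value_heads`; ChatGLM-style models expose
# `multi_query_group_num`; classic MHA models expose neither (0 here).
from transformers import LlamaConfig


def resolve_kv_group_num(config) -> int:
    if hasattr(config, "num_key_value_heads"):
        return config.num_key_value_heads
    if hasattr(config, "multi_query_group_num"):
        return config.multi_query_group_num
    return 0


# Llama-2-70B-like layout: 64 query heads share 8 KV heads (group size 8).
cfg = LlamaConfig(num_attention_heads=64, num_key_value_heads=8)
assert resolve_kv_group_num(cfg) == 8
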
@@ -97,6 +101,7 @@ class TPInferEngine:
         assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
         assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
         self.head_num //= self.tp_size  # update sharded number of heads
+
         if self.multi_query_group_num:
             # NOTE the logic of MQA tensor parallelism should be specified.
             assert (
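
The asserts in this hunk guard the head-sharding arithmetic: each tensor-parallel rank keeps `head_num / tp_size` query heads, and when KV groups exist they must divide evenly across ranks as well. A small illustrative sketch with hypothetical numbers, not taken from the patch:

def shard_heads(head_num: int, kv_group_num: int, tp_size: int) -> tuple[int, int]:
    """Return (query heads, KV heads) held by each tensor-parallel rank."""
    assert head_num % tp_size == 0, f"Cannot shard {head_num} heads with tp size {tp_size}"
    if kv_group_num:
        # KV head groups must also split evenly across ranks.
        assert kv_group_num % tp_size == 0, f"Cannot shard {kv_group_num} KV groups with tp size {tp_size}"
    return head_num // tp_size, (kv_group_num // tp_size if kv_group_num else 0)


# Llama-2-70B-like numbers on 4 GPUs: 16 query heads and 2 KV heads per rank.
assert shard_heads(64, 8, 4) == (16, 2)
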
@@ -116,13 +121,15 @@ class TPInferEngine:

     def _post_init_gptq_buffer(self, model: nn.Module) -> None:
         from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear
+
         HAS_GPTQ_CUDA = False
         try:
             from colossalai.kernel.op_builder.gptq import GPTQBuilder
+
             gptq_cuda = GPTQBuilder().load()
             HAS_GPTQ_CUDA = True
         except ImportError:
-            warnings.warn('CUDA gptq is not installed')
+            warnings.warn("CUDA gptq is not installed")
             HAS_GPTQ_CUDA = False

         for name, submodule in model.named_modules():
@@ -130,8 +137,9 @@ class TPInferEngine:
                 self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)

                 if self.use_act_order:
-                    self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures,
-                                                   submodule.outfeatures)
+                    self.max_inner_outer_dim = max(
+                        self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
+                    )
                 self.bits = submodule.bits
         if not (HAS_GPTQ_CUDA and self.bits == 4):
             return
@@ -141,15 +149,16 @@ class TPInferEngine:
         max_input_len = self.max_input_len
         # The temp_state buffer is required to reorder X in the act-order case.
         # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
-        self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim),
-                                                  dtype=torch.float16,
-                                                  device=torch.cuda.current_device())
-        self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size),
-                                               dtype=torch.float16,
-                                               device=torch.cuda.current_device())
+        self.gptq_temp_state_buffer = torch.zeros(
+            (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
+        )
+        self.gptq_temp_dq_buffer = torch.zeros(
+            (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
+        )

-        gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer,
-                                  self.gptq_temp_dq_buffer)
+        gptq_cuda.prepare_buffers(
+            torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
+        )
         # Using the default from exllama repo here.
         matmul_recons_thd = 8
         matmul_fused_remap = False
@@ -45,7 +45,7 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
     base = float(base)

     # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
-    ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None))
+    ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)

     if ntk_alpha is not None:
         ntk_alpha = float(ntk_alpha)
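
This change also fixes a latent bug: the old line called `float()` on the raw environment lookup and raised a TypeError whenever `INFER_NTK_ALPHA` was unset; now the cast happens only when the variable is present. For background, NTK-aware scaling (see the linked reference) inflates the RoPE base so longer contexts can be handled without fine-tuning. A self-contained sketch of that idea, using the commonly cited exponent form rather than necessarily the exact expression in this file:

import os

import torch


def ntk_scaled_inv_freq(head_dim: int, base: float = 10000.0) -> torch.Tensor:
    """Inverse RoPE frequencies, optionally rescaled via the INFER_NTK_ALPHA env var."""
    ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
    if ntk_alpha is not None:
        ntk_alpha = float(ntk_alpha)
        assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
        # Common NTK-aware trick: inflate the base so low frequencies stretch with context length.
        base = base * ntk_alpha ** (head_dim / (head_dim - 2))
    return 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
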
@@ -5,7 +5,13 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd
+from colossalai.kernel.triton import (
+    llama2_context_attn_fwd,
+    llama_context_attn_fwd,
+    rotary_embedding_fwd,
+    token_attention_fwd,
+)
+from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

 from ._utils import copy_kv_to_mem_cache

@@ -138,6 +144,7 @@ class LlamaInferenceForwards:
             seq_len = infer_state.seq_len
             infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
             infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+            infer_state.other_kv_index = infer_state.block_loc[0, seq_length_with_past - 1].item()

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -261,8 +268,8 @@ class LlamaInferenceForwards:
         # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]

         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)

         # NOTE might want to revise
         # need some way to record the length of past key values cache
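
These two lines are the core Llama-2 (grouped-query attention) fix: `k_proj` and `v_proj` output `num_key_value_heads * head_dim` features, not `num_heads * head_dim`, so viewing them with `num_heads` would mis-shape the tensors. A quick shape check under assumed GQA dimensions:

import torch
import torch.nn as nn

bsz, q_len, head_dim = 2, 5, 128
num_heads, num_key_value_heads = 32, 8  # hypothetical GQA setup

hidden_size = num_heads * head_dim
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=False)

hidden_states = torch.randn(bsz, q_len, hidden_size)
query_states = q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim)
# Viewing the K projection with num_heads would fail; the KV head count must be used.
key_states = k_proj(hidden_states).view(bsz, q_len, num_key_value_heads, head_dim)
print(query_states.shape, key_states.shape)  # (2, 5, 32, 128) (2, 5, 8, 128)
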
@@ -274,11 +281,11 @@ class LlamaInferenceForwards:
         # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )

         rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
-        rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+        rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)

         query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
-        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
-        value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
+        key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim)

         if infer_state.is_context_stage:
             # first token generation
@@ -294,6 +301,7 @@ class LlamaInferenceForwards:

             attn_output = torch.empty_like(query_states)

+            if self.num_key_value_groups == 1:
                 llama_context_attn_fwd(
                     query_states,
                     key_states,
@@ -303,6 +311,16 @@ class LlamaInferenceForwards:
                     infer_state.seq_len,
                     infer_state.cache_manager.past_key_values_length,
                 )
+            else:
+                llama2_context_attn_fwd(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_output,
+                    infer_state.start_loc,
+                    infer_state.seq_len,
+                    infer_state.cache_manager.past_key_values_length,
+                )
         else:
             if infer_state.decode_is_contiguous:
                 # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
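
The new branch routes grouped-query models (`num_key_value_groups > 1`) to `llama2_context_attn_fwd`, while plain MHA keeps the original kernel. A hedged, non-Triton reference of what the GQA prefill path computes, repeating each KV head across its query group (the real kernel avoids materializing this repeat):

import torch


def gqa_context_attention_ref(q, k, v):
    """Reference GQA attention for one prefill sequence.

    q: (seq_len, num_heads, head_dim); k, v: (seq_len, num_kv_heads, head_dim).
    """
    seq_len, num_heads, head_dim = q.shape
    groups = num_heads // k.shape[1]
    # Repeat each KV head so every query head sees its group's keys/values.
    k = k.repeat_interleave(groups, dim=1)
    v = v.repeat_interleave(groups, dim=1)
    scores = torch.einsum("qhd,khd->hqk", q, k) / head_dim**0.5
    # Causal mask for the prefill (context) stage.
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)
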
@@ -330,6 +348,7 @@ class LlamaInferenceForwards:
             # (batch_size, seqlen, nheads, headdim)
             attn_output = torch.empty_like(query_states)

+            if self.num_key_value_groups == 1:
                 token_attention_fwd(
                     query_states,
                     infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
@@ -340,7 +359,18 @@ class LlamaInferenceForwards:
                     infer_state.seq_len,
                     infer_state.cache_manager.past_key_values_length,
                 )
+            else:
+                Llama2TokenAttentionForwards.token_attn(
+                    query_states,
+                    infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+                    infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+                    attn_output,
+                    infer_state.block_loc,
+                    infer_state.start_loc,
+                    infer_state.seq_len,
+                    infer_state.cache_manager.past_key_values_length,
+                    infer_state.other_kv_index,
+                )

         attn_output = attn_output.view(bsz, q_len, self.hidden_size)

         attn_output = self.o_proj(attn_output)
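
Decode follows the same pattern: with grouped KV heads, single-token attention goes through `Llama2TokenAttentionForwards.token_attn`, which also consumes the `other_kv_index` recorded earlier. A minimal torch sketch of the per-token attention it performs over the cached keys/values (again repeating KV heads only for clarity; the kernel reads the cache in place):

import torch


def gqa_decode_attention_ref(q, k_cache, v_cache):
    """q: (num_heads, head_dim); k_cache, v_cache: (past_len, num_kv_heads, head_dim)."""
    num_heads, head_dim = q.shape
    groups = num_heads // k_cache.shape[1]
    k = k_cache.repeat_interleave(groups, dim=1)  # (past_len, num_heads, head_dim)
    v = v_cache.repeat_interleave(groups, dim=1)
    scores = torch.einsum("hd,phd->hp", q, k) / head_dim**0.5
    probs = scores.softmax(dim=-1)
    return torch.einsum("hp,phd->hd", probs, v)  # (num_heads, head_dim)
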
@@ -9,7 +9,7 @@ except ImportError:

 # There may exist import error even if we have triton installed.
 if HAS_TRITON:
-    from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
+    from .context_attention import bloom_context_attn_fwd, llama2_context_attn_fwd, llama_context_attn_fwd
     from .copy_kv_cache_dest import copy_kv_cache_to_dest
     from .fused_layernorm import layer_norm
     from .gptq_triton import gptq_fused_linear_triton
@@ -20,6 +20,7 @@ if HAS_TRITON:

     __all__ = [
         "llama_context_attn_fwd",
+        "llama2_context_attn_fwd",
         "bloom_context_attn_fwd",
         "softmax",
         "layer_norm",
tests/test_infer/test_llama2_infer.py (new file, 69 lines)
@@ -0,0 +1,69 @@
+import os
+
+import pytest
+import torch
+from packaging import version
+from transformers import LlamaForCausalLM
+from transformers.models.llama.configuration_llama import LlamaConfig
+
+import colossalai
+from colossalai.inference.tensor_parallel.engine import TPInferEngine
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import ShardConfig
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
+TPSIZE = 2
+BATCH_SIZE = 8
+MAX_INPUT_LEN = 12
+MAX_OUTPUT_LEN = 100
+
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
+
+
+@parameterize(
+    "test_config",
+    [
+        {
+            "tp_size": TPSIZE,
+        }
+    ],
+)
+def run_llama_test(test_config):
+    llama_config = LlamaConfig(
+        num_hidden_layers=2, num_key_value_heads=8, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024
+    )
+    model = LlamaForCausalLM(llama_config)
+    model = model.half()
+
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
+    )
+    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
+
+    input_tokens = {
+        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+    }
+    outputs = infer_engine.generate(input_tokens, **generate_kwargs)
+
+    assert outputs is not None
+
+
+def check_llama(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_llama_test()
+
+
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_llama():
+    spawn(check_llama, TPSIZE)
+
+
+if __name__ == "__main__":
+    test_llama()
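
The test exercises the whole stack with a tiny random-weight Llama config. Under the same API, plugging in a real Llama-2 checkpoint might look like the sketch below; the checkpoint path is a placeholder and the snippet assumes it runs inside a process that has already called `colossalai.launch` (as `check_llama` does above).

# Hedged usage sketch: swap the random test config for a real Llama-2 checkpoint.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig

MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN = 4, 128, 64

model = LlamaForCausalLM.from_pretrained("/path/to/llama-2-7b-hf", torch_dtype=torch.float16)
tokenizer = LlamaTokenizer.from_pretrained("/path/to/llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token

shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

inputs = tokenizer(["Introduce some landmarks in Beijing"], return_tensors="pt", padding=True)
outputs = engine.generate(inputs, max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
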