Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-27 07:47:05 +00:00)
Add Inference test for llama (#4508)
* add kv cache memory manager
* add stateinfo during inference
* add
* add infer example
* finish
* finish
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* add inference test for llama
* fix conflict
* feature: add some new features for llama engine
* adapt colossalai triton interface
* Change the parent class of llama policy
* add nvtx
* move llama inference code to tensor_parallel
* fix __init__.py
* rm tensor_parallel
* fix: fix bugs in auto_policy.py
* fix:rm some unused codes
* mv colossalai/tpinference to colossalai/inference/tensor_parallel
* change __init__.py
* save change
* fix engine
* Bug fix: Fix hang
* remove llama_infer_engine.py

---------

Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
This commit is contained in:
parent 35af65d240
commit f0aab7f9a8
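For orientation, the end-to-end flow this commit enables (shard a Hugging Face Llama model with ShardFormer, then generate through the new TPInferEngine) can be sketched as follows. The snippet mirrors the test added at the bottom of this diff; the checkpoint path is a placeholder, and a tensor-parallel process group is assumed to be initialized beforehand (the test does this via colossalai.launch inside spawn):

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig, ShardFormer

# Assumes a distributed/tensor-parallel context is already set up, and that the rotary
# cos/sin caches have been attached to the model (see init_to_get_rotary in the test below).
model_path = "/path/to/llama-7b-hf"    # placeholder checkpoint path
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(model_path).half().cuda()

engine = TPInferEngine(model, 4, 12, 8)    # max_batch_size=4, max_input_len=12, max_output_len=8, as in the test
shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
engine.prepare_with_shard_config(shard_config)
engine.shard_model_by(ShardFormer(shard_config=shard_config))

input_ids = tokenizer.encode("Introduce some landmarks in Beijing", return_tensors='pt')
outputs = engine.generate(input_ids, dict(do_sample=False))
print(tokenizer.decode(outputs[0]))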
@@ -1,4 +1,6 @@
from .modeling.llama import LlamaInferenceForwards
from .pollcies.llama import LlamaModelInferPolicy
from .engine import TPInferEngine
from .kvcache_manager import MemoryManager

__all__ = ['MemoryManager', 'TPInferEngine']
__all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy', 'MemoryManager', 'TPInferEngine']

@@ -16,7 +16,7 @@ from .kvcache_manager import MemoryManager

DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2

_supported_models = ['LlamaForCausalLM', 'BloomForCausalLM']
_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM']


class TPInferEngine:

@@ -27,7 +27,7 @@ class TPInferEngine:
                 max_input_len: int,
                 max_output_len: int,
                 dtype: torch.dtype = torch.float16,
                 device: torch.device = torch.cuda.current_device()) -> None:
                 device: str = 'cuda') -> None:
        self.model = model
        self.sharded_model = None

@@ -40,7 +40,7 @@ class TPInferEngine:
        assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
        assert self.max_input_len + self.max_output_len <= 2048, "Max length exceeds the constraint"

        self.device = device
        torch.device(device=device)
        self.dtype = dtype

        self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads

@@ -88,7 +88,7 @@ class TPInferEngine:
        assert model_name in self._supported_models(), f"Unsupported model cls {model_name} for TP inference."
        policy = get_autopolicy(self.model, inference_only=True)
        self.sharded_model, _ = shardformer.optimize(self.model, policy)
        self.sharded_model = self.sharded_model.to(self.device)
        self.sharded_model = self.sharded_model.cuda()

    @staticmethod
    def _supported_models() -> List[str]:

@@ -137,7 +137,7 @@ class TPInferEngine:
            input_tokens = dict(input_ids=input_tokens)
        for t in input_tokens:
            if torch.is_tensor(input_tokens[t]):
                input_tokens[t] = input_tokens[t].to(self.device)
                input_tokens[t] = input_tokens[t].cuda()

        outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)

@@ -173,8 +173,8 @@ class TPInferEngine:
        else:
            batch_size = inputs.shape[0]

        seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device=self.device)
        seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device=self.device)
        seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
        seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
        start_index = 0

        max_len_in_batch = -1

@@ -197,10 +197,10 @@ class TPInferEngine:

        block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len),
                                dtype=torch.long,
                                device=self.device)
                                device='cuda')
        batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
        batch_infer_state.seq_len = seq_lengths.to(self.device) # might want to assign specific device
        batch_infer_state.start_loc = seq_start_indexes.to(self.device)
        batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device
        batch_infer_state.start_loc = seq_start_indexes.to('cuda')
        batch_infer_state.block_loc = block_loc
        batch_infer_state.decode_layer_id = 0
        batch_infer_state.past_key_values_len = 0
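The hunk above builds a BatchInferState that packs per-sequence bookkeeping (start offsets, lengths, and KV-cache block locations) for the whole batch. As a standalone illustration of that layout, assuming a right-padded batch with pad id 0 (not part of the commit):

import torch

input_ids = torch.tensor([[1, 5, 9, 0], [2, 7, 0, 0]])               # right-padded batch, pad id 0 (assumption)
seq_lengths = (input_ids != 0).sum(dim=1).to(torch.int32)            # tensor([3, 2], dtype=torch.int32)
seq_start_indexes = torch.cumsum(seq_lengths, dim=0) - seq_lengths   # tensor([0, 3]): start of each sequence in the packed cache
total_token_num = int(seq_lengths.sum())                             # 5 KV-cache slots needed for the prefill step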
@@ -0,0 +1,3 @@
from .llama import LlamaInferenceForwards

__all__ = ['LlamaInferenceForwards']
colossalai/inference/tensor_parallel/modeling/llama.py (new file, 321 lines)
@@ -0,0 +1,321 @@
from typing import List, Optional, Tuple

import torch
import numpy as np
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
)
from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
from typing import List, Optional, Tuple
from transformers.modeling_outputs import BaseModelOutputWithPast

class LlamaInferenceForwards:
    """
    This class holds forwards for llama inference.
    """

    @staticmethod
    def llama_model_forward(
        self: LlamaModel,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):

        batch_size = input_ids.shape[0] # input_ids.shape[0]

        # infer_state = BatchInferState(batch_size, input_ids.shape[1])
        # infer_state.batch_size = batch_size
        # # NOTE: dummy implementation here for testing, just assume all inputs same length
        # infer_state.block_loc = self.block_loc
        # infer_state.start_loc = self.start_loc
        # infer_state.seq_len = self.seq_len
        # infer_state.max_len_in_batch = self.max_len_in_batch

        infer_state = self.infer_state
        b_seq_len_numpy = infer_state.seq_len.cpu().numpy()
        position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i])
                                                        for i in range(len(b_seq_len_numpy))], axis=0)).cuda()

        # this equals
        infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
        infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            # TODO dummy but work, revise it
            past_key_values_length = infer_state.cache_manager.past_key_values_length
            # past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        # FIXME: differentiate with prefill stage
        # block_loc require different value-assigning method for two different stage
        if use_cache and seq_length != 1:
            # NOTE assuem prefill stage
            # allocate memory block
            infer_state.is_context_stage = True # set prefill stage, notify attention layer
            infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
            infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index)
        else:
            # TODO handle the condition that no contiguous memory presents
            alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
            if alloc_mem is not None:
                infer_state.decode_is_contiguous = True
                infer_state.decode_mem_index = alloc_mem[0]
                infer_state.decode_mem_start = alloc_mem[1]
                infer_state.decode_mem_end = alloc_mem[2]
                infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
            else:
                print(f" *** Encountered allocation non-contiguous")
                print(f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}")
                infer_state.decode_is_contiguous = False
                alloc_mem = infer_state.cache_manager.alloc(batch_size)
                infer_state.decode_mem_index = alloc_mem
                # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
                # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
                infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )

        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        infer_state.decode_layer_id = 0

        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx] if past_key_values is not None else None
            # NOTE: modify here for passing args to decoder layer
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                infer_state=infer_state,
            )
            infer_state.decode_layer_id += 1
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        hidden_states = self.norm(hidden_states)
        next_cache = next_decoder_cache if use_cache else None

        # update indices
        # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
        infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
        infer_state.seq_len += 1

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    @staticmethod
    def llama_decoder_layer_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        infer_state: Optional[BatchInferState] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            infer_state=infer_state,
        )

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


    @staticmethod
    def llama_flash_attn_kvcache_forward(
        self: LlamaAttention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        infer_state: Optional[BatchInferState] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        assert use_cache is True, "use_cache should be set to True using this llama attention"

        bsz, q_len, _ = hidden_states.size()

        # TODO might think about better way to handle transposed k and v
        # key_states            [bs, seq_len, num_heads, head_dim/embed_size_per_head]
        # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
        key_states_transposed = key_states.transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)

        # cos, sin = self.rotary_emb(value_states_transposed, seq_len=kv_seq_len)
        cos ,sin = infer_state.position_cos, infer_state.position_sin

        cos_sin_cache = torch.cat((cos, sin), dim=-1)

        from vllm.pos_encoding_ops import rotary_embedding_neox

        rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache)

        def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
            num_heads = key_buffer.shape[2]
            head_dim = key_buffer.shape[3]
            key_buffer = key_buffer.view(-1, num_heads, head_dim)
            value_buffer = value_buffer.view(-1, num_heads, head_dim)
            copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
            copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
            return

        # copy key and value calculated in current step to memory manager
        if infer_state.is_context_stage:
            _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index, infer_state.cache_manager)
        else:
            _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.decode_mem_index, infer_state.cache_manager)

        # this is worse than destcopy
        # torch.Tensor.copy_(infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],key_states)
        # torch.Tensor.copy_(infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],value_states)

        # FIXME might want to revise
        # need some way to record the length of past key values cache
        # since we won't return past_key_value_cache right now
        if infer_state.decode_layer_id == 0: # once per model.forward
            infer_state.cache_manager.past_key_values_length += q_len # seq_len

        query_states = query_states.transpose(1, 2)

        if infer_state.is_context_stage:
            # first token generation

            # attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(query_states,
            #                                                                                         key_states,
            #                                                                                         value_states,
            #                                                                                         0,
            #                                                                                         1/math.sqrt(self.head_dim),
            #                                                                                         causal,
            #                                                                                         False)

            attn_output = torch.empty_like(query_states)

            # calcu_shape for context_attention_fwd
            calcu_shape1 = (-1, self.num_heads, self.head_dim)

            llama_context_attn_fwd(query_states.view(calcu_shape1),
                                   key_states.view(calcu_shape1),
                                   value_states.view(calcu_shape1),
                                   attn_output.view(calcu_shape1),
                                   infer_state.start_loc,
                                   infer_state.seq_len,
                                   infer_state.cache_manager.past_key_values_length)
        else:
            # second token and follows
            # kv = torch.stack((key_states, value_states), dim=2)
            # (batch_size, seqlen, nheads, headdim)
            attn_output = torch.empty_like(query_states)

            token_attention_fwd(query_states,
                                infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
                                infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
                                attn_output,
                                infer_state.block_loc,
                                infer_state.start_loc,
                                infer_state.seq_len,
                                infer_state.cache_manager.past_key_values_length)

        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # return past_key_value as None
        return attn_output, None, None
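Note that llama_model_forward above indexes into self._cos_cached and self._sin_cached, which are not standard Hugging Face LlamaModel attributes; they must be attached to the model beforehand (the test below does this with its init_to_get_rotary helper). A minimal sketch of that precomputation, with assumed head dimension and context length:

import torch

head_dim, max_seq_len, base = 128, 2048, 10000.0    # assumed sizes for illustration
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
t = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)                     # (max_seq_len, head_dim // 2)
_cos_cached = torch.cos(freqs).half().cuda()         # attached to the LlamaModel instance before inference
_sin_cached = torch.sin(freqs).half().cuda()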
@@ -0,0 +1,3 @@
from .llama import LlamaModelInferPolicy

__all__ = ['LlamaModelInferPolicy']
colossalai/inference/tensor_parallel/pollcies/llama.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from functools import partial

from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

from ..modeling.llama import LlamaInferenceForwards

class LlamaModelInferPolicy(LlamaForCausalLMPolicy):

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
        policy = super().module_policy()
        self.shard_config._infer()

        # example for replace layer or decoder
        # if self.shard_config.enable_flash_attention:
        #     policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
        #         'forward': get_llama_flash_attention_forward(),
        #     })

        infer_forward = LlamaInferenceForwards.llama_model_forward
        method_replacement = {'forward': partial(infer_forward)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)

        infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
        method_replacement = {'forward': partial(infer_forward)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaDecoderLayer)

        infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
        method_replacement = {'forward': partial(infer_forward)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaAttention)

        return policy
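The policy above routes the three inference forwards onto LlamaModel, LlamaDecoderLayer and LlamaAttention through ShardFormer's method_replacement mechanism. A simplified illustration of the effect on a toy module, not ShardFormer's actual implementation:

import types
import torch
import torch.nn as nn

def custom_forward(self, x):
    # stands in for e.g. LlamaInferenceForwards.llama_model_forward
    return self.linear(x) * 2

toy = nn.Module()
toy.linear = nn.Linear(4, 4)
toy.forward = types.MethodType(custom_forward, toy)   # bind the replacement forward to the instance
print(toy(torch.randn(1, 4)).shape)                   # torch.Size([1, 4])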
@@ -391,84 +391,6 @@ class LlamaPipelineForwards:
        return {'hidden_states': hidden_states}


class LlamaInferenceForwards:
    """
    This class holds forwards for llama inference.
    """

    @staticmethod
    def llama_model_forward(
        self: LlamaModel,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[
            torch.LongTensor] = None,    # TODO: this can also be removed if we got sin,cos cached in inferinfo
        past_key_values: Optional[List[
            torch.FloatTensor]] = None,    #TODO: maybe removed after memory cache manager is done.
        inputs_embeds: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        inferinfo=None,
    ):
        # only keep the basic items
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_key_values_length,
                                        seq_length + past_key_values_length,
                                        dtype=torch.long,
                                        device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past),
                                        dtype=torch.bool,
                                        device=inputs_embeds.device)
        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
                                                              past_key_values_length)

        hidden_states = inputs_embeds

        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx] if past_key_values is not None else None
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        if not return_dict:
            return hidden_states
        return BaseModelOutputWithPast(last_hidden_state=hidden_states,)


def get_llama_flash_attention_forward():

    from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
@@ -140,11 +140,15 @@ _INFER_POLICY_LIST = {
}


def import_policy(policy_location: PolicyLocation) -> Policy:
def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
    """
    Dynamically import a Policy class based on the policy location.
    """
    module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"

    if inference_only:
        module_name = f"colossalai.inference.tensor_parallel.pollcies.{policy_location.file_name}"
    else:
        module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
    module = importlib.import_module(module_name)
    return getattr(module, policy_location.class_name)


@@ -181,5 +185,5 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
        )
    else:
        policy = import_policy(policy_location)
        policy = import_policy(policy_location, inference_only)
    return policy()
@@ -7,7 +7,7 @@ from torch.nn import Module

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from ..modeling.llama import LlamaInferenceForwards, LlamaPipelineForwards, get_llama_flash_attention_forward
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']

@@ -263,21 +263,3 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in llama for sequence classification model"""
        return []


class LlamaModelInferPolicy(LlamaPolicy):

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
        policy = super().module_policy()
        # configure default shard config for inference
        self.shard_config._infer()

        infer_forward = LlamaInferenceForwards.llama_model_forward
        method_replacement = {'forward': partial(infer_forward)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)

        return policy
@@ -2,40 +2,72 @@ import os

import pytest
import torch
from torch import distributed as dist
import numpy as np

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_infer._utils import build_model, run_infer
from transformers import LlamaForCausalLM, LlamaTokenizer
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.inference.tensor_parallel.engine import TPInferEngine
import torch.distributed as dist

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
TPSIZE = 2

def init_to_get_rotary(self, base=10000):
    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
    if not hasattr(self.config, "rope_scaling"):
        rope_scaling_factor = 1.0
    else:
        rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
    if hasattr(self.config,"max_sequence_length"):
        max_seq_len = self.config.max_sequence_length
    elif hasattr(self.config,"max_position_embeddings"):
        max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
    else:
        max_seq_len = 2048 * rope_scaling_factor
    base = float(base)
    inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_))
    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
    freqs = torch.outer(t, inv_freq)

def check_infer(model_fn, data_gen_fn, output_transform_fn, test_config):
    org_model, sharded_model = build_model(model_fn, **test_config)

    org_output, infer_output = run_infer(org_model, sharded_model, data_gen_fn, output_transform_fn)

    print('original output', org_output[0])
    print('infer output', infer_output[0])

    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
    return

@parameterize('test_config', [{
    'enable_flash_attention': False,
    'tp_size': TPSIZE,
}])
def run_llama_test(test_config):

    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
    llama_model_path = "/data/scratch/llama-7b-hf"
    tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
    tokenizer.pad_token_id = tokenizer.unk_token_id
    model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
    init_to_get_rotary(model.model, base=10000)
    model = model.half()
    model.to(torch.cuda.current_device())

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        if name != "transformers_llama":
            continue
        check_infer(model_fn, data_gen_fn, output_transform_fn, test_config)
    torch.cuda.empty_cache()
    text = "Introduce some landmarks in Beijing"
    input_ids = tokenizer.encode(text, return_tensors='pt')
    # pg_mesh = ProcessGroupMesh(1, 1, test_config["tp_size"])

    infer_engine = TPInferEngine(model.half(), 4, 12, 8)
    shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
    shardformer = ShardFormer(shard_config=shard_config)

    infer_engine.prepare_with_shard_config(shard_config)
    infer_engine.shard_model_by(shardformer)

    generate_kwargs = dict(do_sample=False)
    outputs = infer_engine.generate(input_ids, generate_kwargs)

    print("outputs: ", outputs)

    output_text = tokenizer.decode(outputs[0])
    print(output_text)


def check_llama(rank, world_size, port):

@@ -48,7 +80,7 @@ def check_llama(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
    spawn(check_llama, 1)
    spawn(check_llama, TPSIZE)


if __name__ == "__main__":