Add Inference test for llama (#4508)

* add kv cache memory manager

* add stateinfo during inference

* add

* add infer example

* finish

* finish

* format

* format

* rename file

* add kv cache test

* revise on BatchInferState

* add inference test for llama

* fix conflict

* feature: add some new features for llama engine

* adapt colossalai triton interface

* Change the parent class of llama policy

* add nvtx

* move llama inference code to tensor_parallel

* fix __init__.py

* rm tensor_parallel

* fix: fix bugs in auto_policy.py

* fix: rm some unused code

* mv colossalai/tpinference to colossalai/inference/tensor_parallel

* change __init__.py

* save change

* fix engine

* Bug fix: Fix hang

* remove llama_infer_engine.py

---------

Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Authored by yuehuayingxueluo on 2023-08-30 12:10:26 +08:00; committed by GitHub
parent 35af65d240
commit f0aab7f9a8
10 changed files with 438 additions and 134 deletions
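As a quick orientation, here is a minimal usage sketch of the tensor-parallel inference path this commit adds, mirroring the new test at the bottom of the diff. The checkpoint path is a placeholder, the code is assumed to run inside a distributed launch (colossalai.launch via spawn) with tp_size ranks, and init_to_get_rotary refers to the helper defined in the test.

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig, ShardFormer

llama_model_path = "path/to/llama-7b-hf"    # placeholder checkpoint path
tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
tokenizer.pad_token_id = tokenizer.unk_token_id
model = LlamaForCausalLM.from_pretrained(llama_model_path).half().to(torch.cuda.current_device())
# init_to_get_rotary(model.model)  # precompute rotary cos/sin caches, as done in the test

# max_batch_size=4, max_input_len=12, max_output_len=8, matching the test below
infer_engine = TPInferEngine(model, 4, 12, 8)
shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)
infer_engine.prepare_with_shard_config(shard_config)
infer_engine.shard_model_by(shardformer)

input_ids = tokenizer.encode("Introduce some landmarks in Beijing", return_tensors='pt')
outputs = infer_engine.generate(input_ids, dict(do_sample=False))
print(tokenizer.decode(outputs[0]))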

View File

@@ -1,4 +1,6 @@
from .modeling.llama import LlamaInferenceForwards
from .pollcies.llama import LlamaModelInferPolicy
from .engine import TPInferEngine
from .kvcache_manager import MemoryManager
__all__ = ['MemoryManager', 'TPInferEngine']
__all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy', 'MemoryManager', 'TPInferEngine']

View File

@@ -16,7 +16,7 @@ from .kvcache_manager import MemoryManager
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
_supported_models = ['LlamaForCausalLM', 'BloomForCausalLM']
_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM']
class TPInferEngine:
@@ -27,7 +27,7 @@ class TPInferEngine:
max_input_len: int,
max_output_len: int,
dtype: torch.dtype = torch.float16,
device: torch.device = torch.cuda.current_device()) -> None:
device: str = 'cuda') -> None:
self.model = model
self.sharded_model = None
@@ -40,7 +40,7 @@ class TPInferEngine:
assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
assert self.max_input_len + self.max_output_len <= 2048, "Max length exceeds the constraint"
self.device = device
torch.device(device=device)
self.dtype = dtype
self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
@@ -88,7 +88,7 @@ class TPInferEngine:
assert model_name in self._supported_models(), f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(self.model, inference_only=True)
self.sharded_model, _ = shardformer.optimize(self.model, policy)
self.sharded_model = self.sharded_model.to(self.device)
self.sharded_model = self.sharded_model.cuda()
@staticmethod
def _supported_models() -> List[str]:
@@ -137,7 +137,7 @@ class TPInferEngine:
input_tokens = dict(input_ids=input_tokens)
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
input_tokens[t] = input_tokens[t].to(self.device)
input_tokens[t] = input_tokens[t].cuda()
outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
@@ -173,8 +173,8 @@ class TPInferEngine:
else:
batch_size = inputs.shape[0]
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device=self.device)
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device=self.device)
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
start_index = 0
max_len_in_batch = -1
@@ -197,10 +197,10 @@ class TPInferEngine:
block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len),
dtype=torch.long,
device=self.device)
device='cuda')
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to(self.device) # might want to assign specific device
batch_infer_state.start_loc = seq_start_indexes.to(self.device)
batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device
batch_infer_state.start_loc = seq_start_indexes.to('cuda')
batch_infer_state.block_loc = block_loc
batch_infer_state.decode_layer_id = 0
batch_infer_state.past_key_values_len = 0
@@ -251,4 +251,4 @@ class TPInferEngine:
# => put information already recorded in batchinferstate and pass it to model forward
# => clear records in engine
def add_request():
raise NotImplementedError()
raise NotImplementedError()
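For readers following the engine changes above, this is a simplified sketch of the per-batch bookkeeping the engine does before building a BatchInferState: cumulative start offsets and per-sequence lengths on the GPU, plus a block_loc table with one cache-slot index per (sequence, position). The helper name and the pad-token check are illustrative placeholders, not the engine's actual implementation.

import torch

def build_infer_state_bookkeeping(input_ids: torch.Tensor, max_input_len: int, max_output_len: int, pad_token_id: int = 0):
    # input_ids: [batch_size, padded_seq_len]; real tokens assumed left-aligned
    batch_size = input_ids.shape[0]
    seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
    seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
    start_index = 0
    for i in range(batch_size):
        cur_len = int((input_ids[i] != pad_token_id).sum())    # placeholder padding check
        seq_lengths[i] = cur_len
        seq_start_indexes[i] = start_index
        start_index += cur_len
    # one KV-cache slot index per (sequence, position); filled during prefill/decode
    block_loc = torch.empty((batch_size, max_input_len + max_output_len), dtype=torch.long, device='cuda')
    return seq_lengths, seq_start_indexes, block_loc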

View File

@@ -0,0 +1,3 @@
from .llama import LlamaInferenceForwards
__all__ = ['LlamaInferenceForwards']

View File

@@ -0,0 +1,321 @@
from typing import List, Optional, Tuple
import torch
import numpy as np
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
)
from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
from typing import List, Optional, Tuple
from transformers.modeling_outputs import BaseModelOutputWithPast
class LlamaInferenceForwards:
"""
This class holds forwards for llama inference.
"""
@staticmethod
def llama_model_forward(
self: LlamaModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
batch_size = input_ids.shape[0]
# infer_state = BatchInferState(batch_size, input_ids.shape[1])
# infer_state.batch_size = batch_size
# # NOTE: dummy implementation here for testing, just assume all inputs same length
# infer_state.block_loc = self.block_loc
# infer_state.start_loc = self.start_loc
# infer_state.seq_len = self.seq_len
# infer_state.max_len_in_batch = self.max_len_in_batch
infer_state = self.infer_state
b_seq_len_numpy = infer_state.seq_len.cpu().numpy()
position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i])
for i in range(len(b_seq_len_numpy))], axis=0)).cuda()
# this equals
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
# TODO dummy but work, revise it
past_key_values_length = infer_state.cache_manager.past_key_values_length
# past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
# FIXME: differentiate with prefill stage
# block_loc require different value-assigning method for two different stage
if use_cache and seq_length != 1:
# NOTE: assume prefill stage
# allocate memory block
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index)
else:
# TODO handle the condition that no contiguous memory presents
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
if alloc_mem is not None:
infer_state.decode_is_contiguous = True
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}")
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
infer_state.decode_layer_id = 0
for idx, decoder_layer in enumerate(self.layers):
past_key_value = past_key_values[idx] if past_key_values is not None else None
# NOTE: modify here for passing args to decoder layer
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
infer_state=infer_state,
)
infer_state.decode_layer_id += 1
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
hidden_states = self.norm(hidden_states)
next_cache = next_decoder_cache if use_cache else None
# update indices
# infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@staticmethod
def llama_decoder_layer_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
infer_state=infer_state,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
@staticmethod
def llama_flash_attn_kvcache_forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
assert use_cache is True, "use_cache should be set to True using this llama attention"
bsz, q_len, _ = hidden_states.size()
# TODO might think about better way to handle transposed k and v
# key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head]
# key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states_transposed = key_states.transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
# cos, sin = self.rotary_emb(value_states_transposed, seq_len=kv_seq_len)
cos, sin = infer_state.position_cos, infer_state.position_sin
cos_sin_cache = torch.cat((cos, sin), dim=-1)
from vllm.pos_encoding_ops import rotary_embedding_neox
rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache)
def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
num_heads = key_buffer.shape[2]
head_dim = key_buffer.shape[3]
key_buffer = key_buffer.view(-1, num_heads, head_dim)
value_buffer = value_buffer.view(-1, num_heads, head_dim)
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
return
# copy key and value calculated in current step to memory manager
if infer_state.is_context_stage:
_copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index, infer_state.cache_manager)
else:
_copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.decode_mem_index, infer_state.cache_manager)
# this is worse than destcopy
# torch.Tensor.copy_(infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],key_states)
# torch.Tensor.copy_(infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],value_states)
# FIXME might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len
query_states = query_states.transpose(1, 2)
if infer_state.is_context_stage:
# first token generation
# attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(query_states,
# key_states,
# value_states,
# 0,
# 1/math.sqrt(self.head_dim),
# causal,
# False)
attn_output = torch.empty_like(query_states)
# calcu_shape for context_attention_fwd
calcu_shape1 = (-1, self.num_heads, self.head_dim)
llama_context_attn_fwd(query_states.view(calcu_shape1),
key_states.view(calcu_shape1),
value_states.view(calcu_shape1),
attn_output.view(calcu_shape1),
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length)
else:
# second token and follows
# kv = torch.stack((key_states, value_states), dim=2)
# (batch_size, seqlen, nheads, headdim)
attn_output = torch.empty_like(query_states)
token_attention_fwd(query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
# return past_key_value as None
return attn_output, None, None
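The forward above writes the keys and values of the current step into the cache manager's per-layer buffers at the slot indices handed out by alloc/alloc_contiguous. As a rough illustration only (not the Triton kernel itself), copy_kv_cache_to_dest behaves like this plain PyTorch scatter:

import torch

def copy_kv_cache_to_dest_reference(kv_src: torch.Tensor, dest_index: torch.Tensor, kv_buffer: torch.Tensor) -> None:
    # kv_src:     [num_tokens, num_heads, head_dim]  keys or values produced this step
    # dest_index: [num_tokens]                       slot ids from the memory manager
    # kv_buffer:  [num_slots,  num_heads, head_dim]  per-layer cache buffer
    kv_buffer[dest_index] = kv_src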

View File

@@ -0,0 +1,3 @@
from .llama import LlamaModelInferPolicy
__all__ = ['LlamaModelInferPolicy']

View File

@@ -0,0 +1,35 @@
from functools import partial
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
from ..modeling.llama import LlamaInferenceForwards
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
policy = super().module_policy()
self.shard_config._infer()
# example for replace layer or decoder
# if self.shard_config.enable_flash_attention:
# policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
# 'forward': get_llama_flash_attention_forward(),
# })
infer_forward = LlamaInferenceForwards.llama_model_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaDecoderLayer)
infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaAttention)
return policy
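For clarity, the three method replacements registered above have roughly the following effect once ShardFormer applies the policy. This is a conceptual sketch (the helper below is hypothetical and not part of the PR), showing only the LlamaModel case:

import types
from transformers.models.llama.modeling_llama import LlamaModel
from colossalai.inference.tensor_parallel.modeling.llama import LlamaInferenceForwards

def bind_infer_forward_sketch(llama_model: LlamaModel) -> None:
    # After binding, llama_model(input_ids=..., use_cache=True) dispatches to
    # LlamaInferenceForwards.llama_model_forward with the module passed as `self`.
    llama_model.forward = types.MethodType(LlamaInferenceForwards.llama_model_forward, llama_model)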

View File

@@ -391,84 +391,6 @@ class LlamaPipelineForwards:
return {'hidden_states': hidden_states}
class LlamaInferenceForwards:
"""
This class holds forwards for llama inference.
"""
@staticmethod
def llama_model_forward(
self: LlamaModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[
torch.LongTensor] = None, # TODO: this can also be removed if we got sin,cos cached in inferinfo
past_key_values: Optional[List[
torch.FloatTensor]] = None, #TODO: maybe removed after memory cache manager is done.
inputs_embeds: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
inferinfo=None,
):
# only keep the basic items
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device)
attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
past_key_values_length)
hidden_states = inputs_embeds
for idx, decoder_layer in enumerate(self.layers):
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
if not return_dict:
return hidden_states
return BaseModelOutputWithPast(last_hidden_state=hidden_states,)
def get_llama_flash_attention_forward():
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

View File

@@ -140,11 +140,15 @@ _INFER_POLICY_LIST = {
}
def import_policy(policy_location: PolicyLocation) -> Policy:
def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
"""
Dynamically import a Policy class based on the policy location.
"""
module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
if inference_only:
module_name = f"colossalai.inference.tensor_parallel.pollcies.{policy_location.file_name}"
else:
module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
module = importlib.import_module(module_name)
return getattr(module, policy_location.class_name)
@@ -181,5 +185,5 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location)
policy = import_policy(policy_location, inference_only)
return policy()
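With the new inference_only flag, resolving a policy for inference is a one-liner. A small sketch, assuming get_autopolicy is exposed from the auto_policy module and that llama models are registered in _INFER_POLICY_LIST:

from transformers import LlamaForCausalLM
from colossalai.shardformer.policies.auto_policy import get_autopolicy   # module path assumed

model = LlamaForCausalLM.from_pretrained("path/to/llama")    # placeholder checkpoint
policy = get_autopolicy(model, inference_only=True)
# inference_only=True makes import_policy load the class from
# colossalai.inference.tensor_parallel.pollcies.<file_name> instead of
# colossalai.shardformer.policies.<file_name>, so this is expected to
# return a LlamaModelInferPolicy instance.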

View File

@@ -7,7 +7,7 @@ from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from ..modeling.llama import LlamaInferenceForwards, LlamaPipelineForwards, get_llama_flash_attention_forward
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
@@ -263,21 +263,3 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama for sequence classification model"""
return []
class LlamaModelInferPolicy(LlamaPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
policy = super().module_policy()
# configure default shard config for inference
self.shard_config._infer()
infer_forward = LlamaInferenceForwards.llama_model_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
return policy

View File

@@ -2,40 +2,72 @@ import os
import pytest
import torch
from torch import distributed as dist
import numpy as np
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_infer._utils import build_model, run_infer
from transformers import LlamaForCausalLM, LlamaTokenizer
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.inference.tensor_parallel.engine import TPInferEngine
import torch.distributed as dist
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
TPSIZE = 2
def init_to_get_rotary(self, base=10000):
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
if not hasattr(self.config, "rope_scaling"):
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
if hasattr(self.config,"max_sequence_length"):
max_seq_len = self.config.max_sequence_length
elif hasattr(self.config,"max_position_embeddings"):
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
else:
max_seq_len = 2048 * rope_scaling_factor
base = float(base)
inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_))
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)
def check_infer(model_fn, data_gen_fn, output_transform_fn, test_config):
org_model, sharded_model = build_model(model_fn, **test_config)
org_output, infer_output = run_infer(org_model, sharded_model, data_gen_fn, output_transform_fn)
print('original output', org_output[0])
print('infer output', infer_output[0])
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
return
@parameterize('test_config', [{
'enable_flash_attention': False,
'tp_size': TPSIZE,
}])
def run_llama_test(test_config):
llama_model_path = "/data/scratch/llama-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
tokenizer.pad_token_id = tokenizer.unk_token_id
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
init_to_get_rotary(model.model, base=10000)
model = model.half()
model.to(torch.cuda.current_device())
text = "Introduce some landmarks in Beijing"
input_ids = tokenizer.encode(text, return_tensors='pt')
# pg_mesh = ProcessGroupMesh(1, 1, test_config["tp_size"])
infer_engine = TPInferEngine(model.half(), 4, 12, 8)
shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
infer_engine.prepare_with_shard_config(shard_config)
infer_engine.shard_model_by(shardformer)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name != "transformers_llama":
continue
check_infer(model_fn, data_gen_fn, output_transform_fn, test_config)
torch.cuda.empty_cache()
generate_kwargs = dict(do_sample=False)
outputs = infer_engine.generate(input_ids, generate_kwargs)
print("outputs: ", outputs)
output_text = tokenizer.decode(outputs[0])
print(output_text)
def check_llama(rank, world_size, port):
@@ -48,7 +80,7 @@ def check_llama(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
spawn(check_llama, 1)
spawn(check_llama, TPSIZE)
if __name__ == "__main__":
test_llama()