diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py index e467b4c73..1535db4c1 100644 --- a/colossalai/inference/tensor_parallel/__init__.py +++ b/colossalai/inference/tensor_parallel/__init__.py @@ -1,4 +1,6 @@ +from .modeling.llama import LlamaInferenceForwards +from .pollcies.llama import LlamaModelInferPolicy from .engine import TPInferEngine from .kvcache_manager import MemoryManager - -__all__ = ['MemoryManager', 'TPInferEngine'] + +__all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy', 'MemoryManager', 'TPInferEngine'] diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index f643d892a..e833ef3bd 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -16,7 +16,7 @@ from .kvcache_manager import MemoryManager DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 -_supported_models = ['LlamaForCausalLM', 'BloomForCausalLM'] +_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM'] class TPInferEngine: @@ -27,7 +27,7 @@ class TPInferEngine: max_input_len: int, max_output_len: int, dtype: torch.dtype = torch.float16, - device: torch.device = torch.cuda.current_device()) -> None: + device: str = 'cuda') -> None: self.model = model self.sharded_model = None @@ -40,7 +40,7 @@ class TPInferEngine: assert self.max_batch_size <= 64, "Max batch size exceeds the constraint" assert self.max_input_len + self.max_output_len <= 2048, "Max length exceeds the constraint" - self.device = device + torch.device(device=device) self.dtype = dtype self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads @@ -88,7 +88,7 @@ class TPInferEngine: assert model_name in self._supported_models(), f"Unsupported model cls {model_name} for TP inference." policy = get_autopolicy(self.model, inference_only=True) self.sharded_model, _ = shardformer.optimize(self.model, policy) - self.sharded_model = self.sharded_model.to(self.device) + self.sharded_model = self.sharded_model.cuda() @staticmethod def _supported_models() -> List[str]: @@ -137,7 +137,7 @@ class TPInferEngine: input_tokens = dict(input_ids=input_tokens) for t in input_tokens: if torch.is_tensor(input_tokens[t]): - input_tokens[t] = input_tokens[t].to(self.device) + input_tokens[t] = input_tokens[t].cuda() outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False) @@ -173,8 +173,8 @@ class TPInferEngine: else: batch_size = inputs.shape[0] - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device=self.device) - seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device=self.device) + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda') + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda') start_index = 0 max_len_in_batch = -1 @@ -197,10 +197,10 @@ class TPInferEngine: block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, - device=self.device) + device='cuda') batch_infer_state = BatchInferState(batch_size, max_len_in_batch) - batch_infer_state.seq_len = seq_lengths.to(self.device) # might want to assign specific device - batch_infer_state.start_loc = seq_start_indexes.to(self.device) + batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device + batch_infer_state.start_loc = seq_start_indexes.to('cuda') batch_infer_state.block_loc = block_loc batch_infer_state.decode_layer_id = 0 batch_infer_state.past_key_values_len = 0 @@ -251,4 +251,4 @@ class TPInferEngine: # => put information already recorded in batchinferstate and pass it to model forward # => clear records in engine def add_request(): - raise NotImplementedError() + raise NotImplementedError() \ No newline at end of file diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py new file mode 100644 index 000000000..1b022f38c --- /dev/null +++ b/colossalai/inference/tensor_parallel/modeling/__init__.py @@ -0,0 +1,3 @@ +from .llama import LlamaInferenceForwards + +__all__ = ['LlamaInferenceForwards'] \ No newline at end of file diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py new file mode 100644 index 000000000..df1b99769 --- /dev/null +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -0,0 +1,321 @@ +from typing import List, Optional, Tuple + +import torch +import numpy as np +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, +) +from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest +from colossalai.kernel.triton.context_attention import llama_context_attn_fwd +from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd +from typing import List, Optional, Tuple +from transformers.modeling_outputs import BaseModelOutputWithPast + +class LlamaInferenceForwards: + """ + This class holds forwards for llama inference. + """ + + @staticmethod + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + + batch_size = input_ids.shape[0] # input_ids.shape[0] + + # infer_state = BatchInferState(batch_size, input_ids.shape[1]) + # infer_state.batch_size = batch_size + # # NOTE: dummy implementation here for testing, just assume all inputs same length + # infer_state.block_loc = self.block_loc + # infer_state.start_loc = self.start_loc + # infer_state.seq_len = self.seq_len + # infer_state.max_len_in_batch = self.max_len_in_batch + + infer_state = self.infer_state + b_seq_len_numpy = infer_state.seq_len.cpu().numpy() + position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i]) + for i in range(len(b_seq_len_numpy))], axis=0)).cuda() + + # this equals + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + # TODO dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + # past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + # FIXME: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if use_cache and seq_length != 1: + # NOTE assuem prefill stage + # allocate memory block + infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index) + else: + # TODO handle the condition that no contiguous memory presents + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print(f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}") + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + infer_state.decode_layer_id = 0 + + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] if past_key_values is not None else None + # NOTE: modify here for passing args to decoder layer + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) + infer_state.decode_layer_id += 1 + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + hidden_states = self.norm(hidden_states) + next_cache = next_decoder_cache if use_cache else None + + # update indices + # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + @staticmethod + def llama_decoder_layer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + + @staticmethod + def llama_flash_attn_kvcache_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + assert use_cache is True, "use_cache should be set to True using this llama attention" + + bsz, q_len, _ = hidden_states.size() + + # TODO might think about better way to handle transposed k and v + # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] + # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states_transposed = key_states.transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + + # cos, sin = self.rotary_emb(value_states_transposed, seq_len=kv_seq_len) + cos ,sin = infer_state.position_cos, infer_state.position_sin + + cos_sin_cache = torch.cat((cos, sin), dim=-1) + + from vllm.pos_encoding_ops import rotary_embedding_neox + + rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache) + + def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + num_heads = key_buffer.shape[2] + head_dim = key_buffer.shape[3] + key_buffer = key_buffer.view(-1, num_heads, head_dim) + value_buffer = value_buffer.view(-1, num_heads, head_dim) + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + return + + # copy key and value calculated in current step to memory manager + if infer_state.is_context_stage: + _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index, infer_state.cache_manager) + else: + _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.decode_mem_index, infer_state.cache_manager) + + # this is worse than destcopy + # torch.Tensor.copy_(infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],key_states) + # torch.Tensor.copy_(infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],value_states) + + # FIXME might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if infer_state.decode_layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len + + query_states = query_states.transpose(1, 2) + + if infer_state.is_context_stage: + # first token generation + + # attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(query_states, + # key_states, + # value_states, + # 0, + # 1/math.sqrt(self.head_dim), + # causal, + # False) + + attn_output = torch.empty_like(query_states) + + # calcu_shape for context_attention_fwd + calcu_shape1 = (-1, self.num_heads, self.head_dim) + + llama_context_attn_fwd(query_states.view(calcu_shape1), + key_states.view(calcu_shape1), + value_states.view(calcu_shape1), + attn_output.view(calcu_shape1), + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length) + else: + # second token and follows + # kv = torch.stack((key_states, value_states), dim=2) + # (batch_size, seqlen, nheads, headdim) + attn_output = torch.empty_like(query_states) + + token_attention_fwd(query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length) + + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + # return past_key_value as None + return attn_output, None, None + + \ No newline at end of file diff --git a/colossalai/inference/tensor_parallel/pollcies/__init__.py b/colossalai/inference/tensor_parallel/pollcies/__init__.py new file mode 100644 index 000000000..d92a3e84d --- /dev/null +++ b/colossalai/inference/tensor_parallel/pollcies/__init__.py @@ -0,0 +1,3 @@ +from .llama import LlamaModelInferPolicy + +__all__ = ['LlamaModelInferPolicy'] \ No newline at end of file diff --git a/colossalai/inference/tensor_parallel/pollcies/llama.py b/colossalai/inference/tensor_parallel/pollcies/llama.py new file mode 100644 index 000000000..570e10ba3 --- /dev/null +++ b/colossalai/inference/tensor_parallel/pollcies/llama.py @@ -0,0 +1,35 @@ +from functools import partial + +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + +from ..modeling.llama import LlamaInferenceForwards + +class LlamaModelInferPolicy(LlamaForCausalLMPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel + policy = super().module_policy() + self.shard_config._infer() + + # example for replace layer or decoder + # if self.shard_config.enable_flash_attention: + # policy[LlamaAttention] = ModulePolicyDescription(method_replacement={ + # 'forward': get_llama_flash_attention_forward(), + # }) + + infer_forward = LlamaInferenceForwards.llama_model_forward + method_replacement = {'forward': partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) + + infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward + method_replacement = {'forward': partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaDecoderLayer) + + infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward + method_replacement = {'forward': partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaAttention) + + return policy \ No newline at end of file diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index a18d700f9..294ab8770 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -391,84 +391,6 @@ class LlamaPipelineForwards: return {'hidden_states': hidden_states} -class LlamaInferenceForwards: - """ - This class holds forwards for llama inference. - """ - - @staticmethod - def llama_model_forward( - self: LlamaModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[ - torch.LongTensor] = None, # TODO: this can also be removed if we got sin,cos cached in inferinfo - past_key_values: Optional[List[ - torch.FloatTensor]] = None, #TODO: maybe removed after memory cache manager is done. - inputs_embeds: Optional[torch.FloatTensor] = None, - return_dict: Optional[bool] = None, - inferinfo=None, - ): - # only keep the basic items - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange(past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, - past_key_values_length) - - hidden_states = inputs_embeds - - for idx, decoder_layer in enumerate(self.layers): - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - ) - - hidden_states = layer_outputs[0] - - hidden_states = self.norm(hidden_states) - - if not return_dict: - return hidden_states - return BaseModelOutputWithPast(last_hidden_state=hidden_states,) - - def get_llama_flash_attention_forward(): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 0ffa7fbee..aa100a065 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -140,11 +140,15 @@ _INFER_POLICY_LIST = { } -def import_policy(policy_location: PolicyLocation) -> Policy: +def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy: """ Dynamically import a Policy class based on the policy location. """ - module_name = f"colossalai.shardformer.policies.{policy_location.file_name}" + + if inference_only: + module_name = f"colossalai.inference.tensor_parallel.pollcies.{policy_location.file_name}" + else: + module_name = f"colossalai.shardformer.policies.{policy_location.file_name}" module = importlib.import_module(module_name) return getattr(module, policy_location.class_name) @@ -181,5 +185,5 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}" ) else: - policy = import_policy(policy_location) + policy = import_policy(policy_location, inference_only) return policy() diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 0de2752cf..5ee95f3be 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -7,7 +7,7 @@ from torch.nn import Module from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from ..modeling.llama import LlamaInferenceForwards, LlamaPipelineForwards, get_llama_flash_attention_forward +from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] @@ -263,21 +263,3 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): def get_shared_params(self) -> List[Dict[int, Tensor]]: """No shared params in llama for sequence classification model""" return [] - - -class LlamaModelInferPolicy(LlamaPolicy): - - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel - policy = super().module_policy() - # configure default shard config for inference - self.shard_config._infer() - - infer_forward = LlamaInferenceForwards.llama_model_forward - method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) - - return policy diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 09a81ef7f..89646ca9f 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -2,40 +2,72 @@ import os import pytest import torch -from torch import distributed as dist +import numpy as np import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.layer.utils import Randomizer -from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo -from tests.test_infer._utils import build_model, run_infer +from transformers import LlamaForCausalLM, LlamaTokenizer +from colossalai.cluster import ProcessGroupMesh +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.inference.tensor_parallel.engine import TPInferEngine +import torch.distributed as dist os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +TPSIZE = 2 +def init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config,"max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config,"max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) -def check_infer(model_fn, data_gen_fn, output_transform_fn, test_config): - org_model, sharded_model = build_model(model_fn, **test_config) - - org_output, infer_output = run_infer(org_model, sharded_model, data_gen_fn, output_transform_fn) - - print('original output', org_output[0]) - print('infer output', infer_output[0]) - + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return @parameterize('test_config', [{ - 'enable_flash_attention': False, + 'tp_size': TPSIZE, }]) def run_llama_test(test_config): + + llama_model_path = "/data/scratch/llama-7b-hf" + tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) + tokenizer.pad_token_id = tokenizer.unk_token_id + model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) + init_to_get_rotary(model.model, base=10000) + model = model.half() + model.to(torch.cuda.current_device()) + + text = "Introduce some landmarks in Beijing" + input_ids = tokenizer.encode(text, return_tensors='pt') + # pg_mesh = ProcessGroupMesh(1, 1, test_config["tp_size"]) + + infer_engine = TPInferEngine(model.half(), 4, 12, 8) + shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) + shardformer = ShardFormer(shard_config=shard_config) - sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + infer_engine.prepare_with_shard_config(shard_config) + infer_engine.shard_model_by(shardformer) - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name != "transformers_llama": - continue - check_infer(model_fn, data_gen_fn, output_transform_fn, test_config) - torch.cuda.empty_cache() + generate_kwargs = dict(do_sample=False) + outputs = infer_engine.generate(input_ids, generate_kwargs) + + print("outputs: ", outputs) + + output_text = tokenizer.decode(outputs[0]) + print(output_text) def check_llama(rank, world_size, port): @@ -48,7 +80,7 @@ def check_llama(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_llama(): - spawn(check_llama, 1) + spawn(check_llama, TPSIZE) if __name__ == "__main__":