diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index e99eb364e..c4adba82b 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -1,7 +1,6 @@
 """
-Our config consists of two parts:
+Our config consists of one part:
 1. inference_config: configs for inference, it is a unified api that wraps all the configs for inference.
-    2. generation_config: configs for generation, it is inherited from huggingface.
 """
 
 import logging
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 3aad5ad97..7ac804c1c 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -46,6 +46,7 @@ class InferenceEngine:
     ) -> None:
         assert inference_config, "Please provide inference_config."
         self.tokenizer = tokenizer
+        self.tokenizer.pad_token = self.tokenizer.eos_token
         self.inference_config = inference_config
         self.model_config = model.config
 
@@ -169,9 +170,7 @@ class InferenceEngine:
 
         if prompts_token_ids is None:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-            prompts_token_ids = []
-            for prompt in prompts:
-                prompts_token_ids.append(self.tokenizer.encode(prompt))
+            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts)["input_ids"]
 
         prompts_num = len(prompts_token_ids)
 
@@ -212,11 +211,14 @@ class InferenceEngine:
         self.logger.info("Running generation step")
 
         output_list = []
-        self.request_handler.schedule()
+        batch, k_cache, v_cache = self.request_handler.schedule()
 
-        # Uncomment if the development of RequestHandler is completed.
-        # logits = self.model(batch)
-        # self.request_handler.search_tokens(logits, self.generation_config)
+        logits = self.model(
+            batch,
+            k_cache,
+            v_cache,
+        )
+        self.request_handler.search_tokens(logits, self.generation_config)
 
         finished_sequences = self.request_handler.update()
 
diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py
index bcd213013..50eac0854 100644
--- a/colossalai/inference/kv_cache/kvcache_manager.py
+++ b/colossalai/inference/kv_cache/kvcache_manager.py
@@ -110,6 +110,10 @@ class KVCacheManager:
         """Get the number of available cache blocks."""
         return self._available_blocks
 
+    def get_kv_cache(self):
+        """Get k_cache and v_cache."""
+        return self._kv_cache[0], self._kv_cache[1]
+
     def get_max_blocks_per_sequence(self) -> int:
         """Get the maximum number of blocks that can be allocated for a single sequence."""
         # TODO Consider removing this function as we plan to implement "half-dynamic" batching in schduler/request handler,
diff --git a/colossalai/inference/modeling/models/llama.py b/colossalai/inference/modeling/models/llama.py
new file mode 100644
index 000000000..6c1d844d0
--- /dev/null
+++ b/colossalai/inference/modeling/models/llama.py
@@ -0,0 +1,208 @@
+# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
+from typing import List, Optional, Tuple
+
+import torch
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
+
+from colossalai.inference.struct import BatchInfo
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def llama_causal_lm_forward(
+    self: LlamaForCausalLM,
+    batch: BatchInfo = None,
+    k_caches: List[torch.Tensor] = None,
+    v_caches: List[torch.Tensor] = None,
+):
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    hidden_states = llama_model_forward(
+        self.model,
+        batch=batch,
+        k_caches=k_caches,
+        v_caches=v_caches,
+    )
+    logits = self.lm_head(hidden_states)
+
+    return logits
+
+
+def llama_model_forward(
+    self: LlamaModel,
+    batch: BatchInfo = None,
+    k_caches: List[torch.Tensor] = None,
+    v_caches: List[torch.Tensor] = None,
+):
+    input_ids = batch.get_batch_inputs()
+    block_tables = batch.get_block_table_tensor()
+    sequence_lengths = batch.get_sequence_lengths()
+
+    seq_length = input_ids.shape[1]
+    device = input_ids.device
+
+    past_key_values_length = block_tables.shape[1]
+
+    position_ids = torch.arange(
+        past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+    )
+    position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+
+    hidden_states = self.embed_tokens(input_ids)
+
+    for layer_id, decoder_layer in enumerate(self.layers):
+        hidden_states = decoder_layer(
+            hidden_states,
+            position_ids=position_ids,
+            block_tables=block_tables,
+            k_cache=k_caches[layer_id],
+            v_cache=v_caches[layer_id],
+            is_prompts=batch.is_prompts,
+            sequence_lengths=sequence_lengths,
+        )
+
+    hidden_states = self.norm(hidden_states)
+
+    return hidden_states
+
+
+def llama_decoder_layer_forward(
+    self: LlamaDecoderLayer,
+    hidden_states: torch.Tensor,
+    position_ids: torch.LongTensor,
+    block_tables: torch.Tensor = None,
+    k_cache: torch.Tensor = None,
+    v_cache: torch.Tensor = None,
+    is_prompts: bool = True,
+    sequence_lengths: torch.Tensor = None,
+) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    residual = hidden_states
+
+    hidden_states = self.input_layernorm(hidden_states)
+    # Self Attention
+    hidden_states = self.self_attn(
+        hidden_states=hidden_states,
+        position_ids=position_ids,
+        block_tables=block_tables,
+        k_cache=k_cache,
+        v_cache=v_cache,
+        is_prompts=is_prompts,
+        sequence_lengths=sequence_lengths,
+    )
+
+    hidden_states = residual + hidden_states
+
+    # Fully Connected
+    residual = hidden_states
+    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.mlp(hidden_states)
+    hidden_states = residual + hidden_states
+
+    return hidden_states
+
+
+# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
+def llama_attn_forward(
+    self: LlamaAttention,
+    hidden_states: torch.Tensor,
+    position_ids: torch.LongTensor,
+    block_tables: torch.Tensor = None,
+    k_cache: torch.Tensor = None,
+    v_cache: torch.Tensor = None,
+    is_prompts: bool = True,
+    sequence_lengths: torch.Tensor = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2] + block_tables.shape[1]
+
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
+
+    query_states = query_states.view(-1, self.num_heads, self.head_dim)
+    key_states = key_states.view(-1, self.num_heads, self.head_dim)
+    value_states = value_states.view(-1, self.num_heads, self.head_dim)
+
+    block_size = k_cache.shape[-1]
+
+    memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size)
+
+    if is_prompts:
+        attn_output = context_attention_unpadded(
+            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+        )
+    else:
+        attn_output = torch.empty(bsz, self.num_heads, self.head_dim, dtype=query_states.dtype, device=query_states.device)
+        decoding_attention(
+            query_states,
+            k_cache,
+            v_cache,
+            block_tables,
+            sequence_lengths,
+            attn_output,
+            block_tables.shape[1],
+            block_size,
+        )
+
+    attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output
+
+
+def memcpy_to_block(key, value, k_cache, v_cache, block_tables, block_size):
+    block_table_list = block_tables.tolist()
+    batch_size, seq_len, num_heads, head_dim = key.shape
+
+    reshape_key = key.reshape(batch_size, seq_len, block_size, num_heads, head_dim).permute(0, 2, 3, 1)
+    reshape_value = value.reshape(batch_size, seq_len, block_size, num_heads, head_dim).permute(0, 2, 3, 1)
+    if seq_len == 1:
+        for i in range(batch_size):
+            k_cache[block_table_list[i][-1], :] = reshape_key[i]
+            v_cache[block_table_list[i][-1], :] = reshape_value[i]
+    else:
+        for i in range(batch_size):
+            k_cache[block_table_list[i], :] = reshape_key[i]
+            v_cache[block_table_list[i], :] = reshape_value[i]
diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py
index f0725dc80..3c616c6ce 100644
--- a/colossalai/inference/struct.py
+++ b/colossalai/inference/struct.py
@@ -183,13 +183,16 @@ class BatchInfo:
 
         return cls(sequences_set=sequences_set)
 
-    def get_block_table_tensor(self):
+    def get_block_table_tensor(self) -> torch.Tensor:
         tesnor_list = []
+        block_table = None
         for seq in self.sequences_set:
             block_table = seq.block_table
             assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
             tesnor_list.append(seq.block_table)
-        return torch.concat(tesnor_list)
+        assert tesnor_list, "Batch has not been initialized yet. Please initialize batch first."
+        block_table = torch.concat(tesnor_list)
+        return block_table
 
     def clear_batch(self) -> None:
         """
@@ -271,3 +274,38 @@ class BatchInfo:
         Get batch_size of this batch
         """
         return len(self.sequences_set)
+
+    def get_batch_inputs(self) -> torch.LongTensor:
+        """
+        Get batch inputs for forward inference computation.
+        """
+        input_list = []
+
+        for seq in self.sequences_set:
+            if self.is_prompts:
+                input_list.append(seq.input_token_id)
+            else:
+                input_list.append([seq.output_token_id[-1]])
+
+        return torch.tensor(input_list, dtype=torch.long)
+
+    def get_1D_inputs(self) -> torch.LongTensor:
+        """
+        Flatten the input tokens into a 1D tensor.
+        """
+        input_list = []
+        for seq in self.sequences_set:
+            if self.is_prompts:
+                input_list.extend(seq.input_token_id)
+            else:
+                input_list.append(seq.output_token_id[-1])
+        return torch.tensor(input_list, dtype=torch.long)
+
+    def get_sequence_lengths(self) -> torch.Tensor:
+        """
+        Get the length of each sequence in this batch.
+        """
+        len_list = []
+        for seq in self.sequences_set:
+            len_list.append(seq.get_sentence_len())
+        return torch.tensor(len_list, dtype=torch.int)