[Inference/Spec-Dec] Merge pull request #5565 from hpcaitech/feat/speculative-decoding

Add Speculative Decoding and GLIDE Spec-Dec
Yuanheng Zhao 2024-04-10 18:39:27 +08:00 committed by GitHub
commit 25928d8496
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1690 additions and 194 deletions

View File

@ -133,7 +133,7 @@ We offer 3 main sampling strategies now (i.e. `greedy sample`, `multinomial samp
| Model | KV Cache | Paged Attention | Kernels | Tensor Parallelism | Speculative Decoding |
| - | - | - | - | - | - |
| Llama | ✅ | ✅ | ✅ | 🔜 | ✅ |
Notations:
@ -148,7 +148,7 @@ Notations:
- [x] High-Performance Kernels
- [x] Llama Modelling
- [x] User Documentation
- [x] Speculative Decoding
- [ ] Tensor Parallelism
- [ ] Beam Search
- [ ] Early stopping
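A minimal usage sketch of the new speculative-decoding API (model paths and the prompt are placeholders, and the import paths are assumed; the engine calls follow the `enable_spec_dec` docstring added in this PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# Placeholder checkpoints: any supported Llama verifier plus a small Llama drafter.
model = AutoModelForCausalLM.from_pretrained("path/to/llama-verifier", torch_dtype=torch.float16)
drafter_model = AutoModelForCausalLM.from_pretrained("path/to/llama-drafter", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("path/to/llama-verifier")

engine = InferenceEngine(model, tokenizer, InferenceConfig(max_n_spec_tokens=5))
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
out = engine.generate(prompts=["What is speculative decoding?"])  # speculative decoding
engine.disable_spec_dec()                                          # back to normal generation
```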

View File

@ -42,6 +42,9 @@ class BatchBucket:
self.device = device or get_current_device()
self.dtype = dtype
self._use_spec_dec = False
self._num_tokens_to_verify = None
self._current_batch_size = 0
self._sequences_dict = dict()
self._sequences_indexes = dict()  # deque(maxlen=self.max_batch_size)
@ -88,6 +91,27 @@ class BatchBucket:
== torch.nonzero(self._block_tables[:, 0] >= 0).numel()
)
@property
def use_spec_dec(self) -> bool:
return self._use_spec_dec
@property
def num_tokens_to_verify(self) -> int:
return self._num_tokens_to_verify
def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
"""Set batch bucket to use speculatvie decoding.
This will notify the adjust the lengths of inputs during modeling,
and let the main model verifies tokens in parallel.
"""
self._use_spec_dec = True
self._num_tokens_to_verify = num_tokens_to_verify
def reset_use_spec_dec(self) -> None:
"""Reset the usage of speculative decoding for the batch bucket"""
self._use_spec_dec = False
self._num_tokens_to_verify = None
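A small illustration of how these toggles are meant to be used (assuming an already-constructed `BatchBucket` named `bb`):

```python
# Illustrative only: flip a BatchBucket into verification mode and back.
bb.set_use_spec_dec(num_tokens_to_verify=5)
assert bb.use_spec_dec and bb.num_tokens_to_verify == 5

bb.reset_use_spec_dec()
assert not bb.use_spec_dec and bb.num_tokens_to_verify is None
```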
def _make_compact(self) -> None:
# Clean and Compress the batch based on its sequences dict.
# Namely, compress sequences to the front and clean the seq lengths and block tables tensors.
@ -347,6 +371,23 @@ class BatchBucket:
seq.check_finish()
self._sequence_lengths[: self.current_batch_size] += 1
def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
"""Revoke the last n output tokens of the sequences in the batch
Args:
n_tokens (int): The number of output tokens to revoke from each sequence.
Context (input) tokens are not counted.
n_seqs (int): The first n sequences to revoke tokens from. Defaults to 1.
For now, speculative decoding only supports batch size 1.
"""
if n_tokens >= 1:
seqs_iter = iter(self._sequences_dict.items())
for _ in range(n_seqs):
seq_id, seq = next(seqs_iter)
assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
seq.output_token_id = seq.output_token_id[:-n_tokens]
self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens
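The revoke step simply trims the tail of each sequence's output and shrinks its recorded length; a plain-Python illustration of the same bookkeeping (not engine code):

```python
# Illustrative only: one sequence with 6 context tokens and 4 generated tokens.
output_token_id = [11, 22, 33, 44]
seq_length = 10  # context (6) + output (4)

n_tokens = 3  # e.g. three drafted tokens rejected by the verifier
output_token_id = output_token_id[:-n_tokens]
seq_length -= n_tokens
assert output_token_id == [11] and seq_length == 7
```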
def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
"""Clear all the sequences in the batch.
@ -401,6 +442,21 @@ class BatchBucket:
return True
return False
def get_1D_inputs_spec_dec(self, n: int) -> torch.Tensor:
# Used for main model verification in **Decoding Stage**
# `n` is the number of tokens to be verified,
# so the last `n` tokens of each sequence are prepared as the inputs
assert len(self._sequences_dict) > 0, "No sequence in the batch"
assert all(
seq.output_len >= n for seq in self._sequences_dict.values()
), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
out_li = []
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
out_li.extend(seq.output_token_id[-n:])
return torch.tensor(out_li, dtype=torch.long, device=self.device)
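For intuition, with a single sequence whose generated tokens are `[5, 6, 7, 8]` and `n = 3`, the method returns `tensor([6, 7, 8])`; a standalone sketch of the gather (illustrative only):

```python
import torch

# Flatten the last n output tokens of each sequence into one 1-D tensor.
outputs_per_seq = [[5, 6, 7, 8]]  # spec-dec currently runs with batch size 1
n = 3
flat = torch.tensor([t for seq in outputs_per_seq for t in seq[-n:]], dtype=torch.long)
assert flat.tolist() == [6, 7, 8]
```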
# For compatibility
def get_1D_inputs(self) -> torch.Tensor:
assert len(self._sequences_dict) > 0, "No sequence in the batch"
@ -411,8 +467,6 @@ class BatchBucket:
seq.output_len == 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
out_li = []
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
@ -420,6 +474,10 @@ class BatchBucket:
return torch.tensor(out_li, dtype=torch.long, device=self.device)
else:
# Assume decoding stage
if self.use_spec_dec:
# For Speculative Decoding
# the number of tokens to be verified in parallel plus the correct token in the last step
return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
assert all(
seq.output_len > 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"

View File

@ -26,7 +26,7 @@ _ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
_DEFAULT_PROMPT_TEMPLATES = {
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
"vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
}
@ -46,6 +46,8 @@ class InputMetaData:
head_dim (int, optional): Head dimension. Defaults to 32.
high_precision(bool, optional): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, Defaults to False.
dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32.
use_spec_dec (bool): Indicate whether to use speculative decoding.
num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True.
""" """
block_tables: torch.Tensor = None block_tables: torch.Tensor = None
@ -59,9 +61,22 @@ class InputMetaData:
head_dim: int = 32 head_dim: int = 32
high_precision: bool = False high_precision: bool = False
dtype: torch.dtype = torch.float32 dtype: torch.dtype = torch.float32
use_spec_dec: bool = False
num_tokens_to_verify: int = 0
def __repr__(self) -> str:
return (
f"InputMetaData(block_tables={self.block_tables}, "
f"sequence_lengths={self.sequence_lengths}, "
f"fd_inter_tensor={self.fd_inter_tensor}, "
f"batch_size={self.batch_size}, "
f"is_prompts={self.is_prompts}, "
f"use_cuda_kernel={self.use_cuda_kernel}, "
f"use_cuda_graph={self.use_cuda_graph}, "
f"kv_seq_len={self.kv_seq_len}, "
f"use_spec_dec={self.use_spec_dec}, "
f"num_tokens_to_verify={self.num_tokens_to_verify})"
)
@dataclass
@ -84,6 +99,8 @@ class InferenceConfig:
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
max_n_spec_tokens (int): The maximum number of tokens to speculate per round, defaults to 5.
glimpse_large_kv (bool): Whether the drafter model glimpses the large (main) model's KV cache, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16.
tp_size (int): Tensor parallel size, defaults to 1.
pp_size (int): Pipeline parallel size, defaults to 1.
@ -118,6 +135,10 @@ class InferenceConfig:
top_p: Optional[float] = None
min_p: Optional[float] = None
# speculative decoding configs
max_n_spec_tokens: int = 5
glimpse_large_kv: bool = False
# paged attention configs
block_size: int = 16
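A hedged configuration sketch: only `max_n_spec_tokens`, `glimpse_large_kv`, and `block_size` are taken from this diff; all other fields of `InferenceConfig` are left at their defaults.

```python
from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(
    max_n_spec_tokens=5,     # upper bound on tokens drafted per speculation round
    glimpse_large_kv=False,  # set True only when the drafter glimpses the verifier's KV cache (GLIDE)
    block_size=16,           # paged-attention block size
)
```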

View File

@ -12,6 +12,7 @@ from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
@ -52,19 +53,27 @@ class InferenceEngine:
verbose: bool = False,
model_policy: Policy = None,
) -> None:
assert inference_config, "Please provide inference_config."
assert tokenizer, "Please provide a tokenizer, either a defined one or str"
self.inference_config = inference_config
self.model_config = model.config
self.model = model
self.device = torch.device("cuda")
self.dtype = inference_config.dtype
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.high_precision = inference_config.high_precision
self._verify_args()
self.generation_config = inference_config.to_generation_config(self.model_config)
model.eval()
model = model.to(self.dtype)
model = model.to(self.device)
# Model and related attributes for speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
if model_policy is None:
if self.inference_config.pad_input:
@ -174,21 +183,18 @@ class InferenceEngine:
if self.verbose:
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
def _verify_args(self) -> None:
"""Verify the input args"""
if not isinstance(self.inference_config, InferenceConfig):
raise TypeError("Invalid type of inference config provided.")
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
if self.model.__class__.__name__ not in _supported_models:
raise ValueError(f"Model {self.model.__class__.__name__} is not supported.")
def _shardformer(
self,
@ -224,6 +230,199 @@ class InferenceEngine:
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model
def enable_spec_dec(
self,
drafter_model: nn.Module = None,
n_spec_tokens: int = None,
use_glide_drafter: bool = False,
) -> None:
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
Args:
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
If provided, the previously set drafter and drafter model, if any, will be overwritten.
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
If True, the drafter model will be replaced by a glide model.
```python
...
engine = InferenceEngine(model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
engine.generate(...) # Speculative Decoding
engine.disable_spec_dec()
engine.generate(...) # Normal generation
engine.enable_spec_dec()
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
engine.clear_spec_dec()
```
"""
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
self.n_spec_tokens = n_spec_tokens
if drafter_model is not None:
assert isinstance(drafter_model, nn.Module)
# overwrite the drafter, if exists
self.clear_spec_dec()
self.drafter_model = drafter_model
self.drafter = Drafter(
self.drafter_model,
self.tokenizer,
device=self.device,
dtype=self.dtype,
)
# check if the provided drafter model is compatible with GLIDE structure
# when `use_glide_drafter` is set to True
if (
use_glide_drafter
and hasattr(drafter_model, "model")
and hasattr(drafter_model.model, "layers")
and hasattr(drafter_model.model.layers[0], "cross_attn")
):
self.use_glide = use_glide_drafter
elif use_glide_drafter:
self.logger.warning(
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
f"but the provided drafter model is not compatible with GLIDE structure."
f"Falling back to use the default drafter model (non-GLIDE)."
)
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
# using speculative decoding for subsequent generations
self.use_spec_dec = True
def disable_spec_dec(self) -> None:
"""Disable using speculative decoding for subsequent generations."""
self.request_handler.unset_spec_dec_mode()
# set back to the maximum number of tokens to speculate
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.use_glide = False
self.use_spec_dec = False
def clear_spec_dec(self) -> None:
"""Clear relatable structures of speculative decoding, if exist."""
if self.use_spec_dec:
self.disable_spec_dec()
if self.drafter_model or self.drafter:
self.drafter_model = None
self.drafter = None
torch.cuda.empty_cache()
self.use_glide = False
self.use_spec_dec = False
def steps_spec_dec(self) -> List[Sequence]:
"""
Run speculative decoding steps. This retrieves a single batch and runs inference
with repeated rounds of speculation by the drafter model and verification by the main model.
Returns:
List[Sequence]: finished sequences generated by one step.
"""
batch = self.request_handler.schedule() # prefill batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
already_allocated_kv_len = batch.seq_lengths[0].item()
input_token_ids = batch.get_1D_inputs_spec_dec(1)
finished_sequences = self.request_handler.update()
while True:
# HACK Retrieve the running batch
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
batch = self.request_handler.running_bb # running batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
# 3. Decoding - Drafter model speculates `n` tokens
glide_input = None
if self.use_glide:
glide_input = GlideInput(
batch.get_block_table_tensor(),
self.k_cache[-1],  # use the KV caches of the last layer
self.v_cache[-1],
batch.get_sequence_lengths(),
)
drafter_out = self.drafter.speculate(
input_token_ids,
self.n_spec_tokens,
drafter_past_key_values,
glide_input=glide_input,
)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
drafter_spec_length = drafter_out.speculated_length
for next_token_id_spec in next_token_ids_spec:
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
cur_length = batch.seq_lengths[0].item()
if already_allocated_kv_len < cur_length:
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
already_allocated_kv_len = cur_length
# 4. Decoding - Main model verifies `n` tokens in parallel
if drafter_spec_length < batch.num_tokens_to_verify:
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
# revoke appended tokens for each Sequence in the current batch
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
# append the last correct token generated by the main model
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
# trim past key values of the drafter model
drafter_past_key_values = Drafter.trim_kv_cache(
drafter_past_key_values, drafter_spec_length - n_matches - 1
)
# prepare inputs for the next round of speculation
n = 1 if n_matches < drafter_spec_length else 2
input_token_ids = batch.get_1D_inputs_spec_dec(n)
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
finished_sequences = self.request_handler.update()
if len(finished_sequences) > 0:
break
# Reset the batch's number of tokens to verify.
# This handles the last round of speculation, where the drafter may have speculated
# fewer tokens than the number configured on the engine.
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
return finished_sequences
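The acceptance rule in step 5 keeps drafted tokens up to the first mismatch with the verifier's samples and then appends the verifier's own token; a standalone sketch of that comparison (illustrative, mirroring the `diff_indexes`/`n_matches` logic above):

```python
import torch

def count_accepted(drafted: torch.Tensor, verified: torch.Tensor) -> int:
    """drafted: the n tokens proposed by the drafter;
    verified: the n + 1 tokens sampled from the main model's logits."""
    mismatch = torch.nonzero(verified[: drafted.numel()] != drafted)
    return drafted.numel() if mismatch.numel() == 0 else int(mismatch[0, 0])

# The drafter proposed [3, 5, 9] but the verifier sampled [3, 5, 2, 7]:
# two drafted tokens are accepted, then the verifier's token 2 is appended.
assert count_accepted(torch.tensor([3, 5, 9]), torch.tensor([3, 5, 2, 7])) == 2
```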
def generate(
self,
prompts: List[str] = None,
@ -246,7 +445,6 @@ class InferenceEngine:
List[str]: Inference result returned by one generation.
"""
with torch.inference_mode():
if prompts is not None or prompts_token_ids is not None:
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
@ -257,8 +455,13 @@ class InferenceEngine:
if generation_config is not None:
    self.generation_config = generation_config
if self.use_spec_dec:
    assert self.drafter is not None, "Drafter Model is not initialized."
    while self.request_handler.check_unfinished_seqs():
        output_seqs_list += self.steps_spec_dec()
else:
    while self.request_handler.check_unfinished_seqs():
        output_seqs_list += self.step()
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
@ -368,18 +571,19 @@ class InferenceEngine:
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
input_ids = batch.get_1D_inputs()
sequence_lengths = batch.get_sequence_lengths()
if batch.is_prompts:
    n_tokens = sequence_lengths.sum().item()
else:
    n_tokens = batch.current_batch_size
    if batch.use_spec_dec:
        n_tokens = batch.num_tokens_to_verify + 1
        assert n_tokens == input_ids.size(0)
        n_tokens = n_tokens * batch.current_batch_size
output_tensor = torch.zeros(
    (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph = False
@ -398,6 +602,8 @@ class InferenceEngine:
kv_seq_len=sequence_lengths.max().item(),
head_dim=batch.head_dim,
dtype=batch.dtype,
use_spec_dec=batch.use_spec_dec,
num_tokens_to_verify=batch.num_tokens_to_verify,
)
return input_ids, output_tensor, input_meta_data
@ -428,7 +634,8 @@ class InferenceEngine:
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()

View File

@ -134,8 +134,12 @@ class RequestHandler:
if fd_inter_tensor._tensors_initialized:
fd_inter_tensor._reset()
# For Spec-Dec, process the speculated tokens plus the token in the last step for each seq
max_n_tokens = self.max_batch_size
max_n_tokens *= self.inference_config.max_n_spec_tokens + 1
fd_inter_tensor.initialize(
max_batch_size=max_n_tokens,
num_attn_heads=model_config.num_attention_heads,
kv_max_split_num=kv_max_split_num,
head_dim=head_dim,
@ -177,6 +181,14 @@ class RequestHandler:
def get_kvcache(self):
return self.cache_manager.get_kv_cache()
def set_spec_dec_mode(self, n_spec_tokens: int):
self.prefill_bb.set_use_spec_dec(n_spec_tokens)
self.running_bb.set_use_spec_dec(n_spec_tokens)
def unset_spec_dec_mode(self):
self.prefill_bb.reset_use_spec_dec()
self.running_bb.reset_use_spec_dec()
def schedule(self):
"""
The main logic of request handler.
@ -204,7 +216,11 @@ class RequestHandler:
lst.remove(seq)
if self.running_list.ready_for_prefill():
num_seqs_to_add = min(self.running_list.prefill_seq_num, self.prefill_bb.available_batch_size)
# overwrite the number of sequences to add to 1 if use_spec_dec is enabled
# TODO (zhaoyuanheng): support speculative decoding for batch size > 1
if self.prefill_bb.use_spec_dec:
num_seqs_to_add = 1
for seq in self.running_list.prefill[:num_seqs_to_add]:
seq.mark_running()
@ -230,6 +246,13 @@ class RequestHandler:
return self.running_bb
def allocate_batch_spec_dec(self, batch: BatchBucket, n: int):
assert batch.use_spec_dec
if n > 0:
self.cache_manager.allocate_n_tokens_from_block_tables(
batch.block_tables, batch.seq_lengths, batch.current_batch_size, n=n
)
def add_sequence(self, req: Sequence):
"""
Add the request to waiting list.
@ -282,13 +305,21 @@ class RequestHandler:
return sample_tokens
def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
if (
sequence.output_token_id[-1] == generation_config.eos_token_id
or sequence.output_len >= generation_config.max_length
):
sequence.mark_finished()
def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
for seq in batch.seqs_li:
if (
seq.output_token_id[-1] == generation_config.eos_token_id
or seq.output_len >= generation_config.max_length
):
seq.mark_finished()
def check_unfinished_seqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty()
@ -309,9 +340,20 @@ class RequestHandler:
# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
return sample_tokens
def append_next_tokens(self, sample_tokens: torch.Tensor):
assert sample_tokens.dim() == 1
n_elements = sample_tokens.size(0)
if not self.prefill_bb.is_empty:
assert (
self.prefill_bb.current_batch_size == n_elements
), f"Incompatible size: {n_elements} tokens to append while prefill batch size {self.prefill_bb.current_batch_size}"
self.prefill_bb.append_batch_tokens(sample_tokens)
else:
assert (
self.running_bb.current_batch_size == n_elements
), f"Incompatible size: {n_elements} tokens to append while running batch size {self.running_bb.current_batch_size}"
self.running_bb.append_batch_tokens(sample_tokens)
def update(self):

View File

@ -349,6 +349,26 @@ class KVCacheManager:
return seqs_to_recycle
def allocate_n_tokens_from_block_tables(
self,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
bsz: int,
n: int,
) -> List[int]:
"""Allocate logical cache blocks for `n` new tokens for a batch of sequences during decoding stage."""
assert block_tables.dim() == 2
assert context_lens.dim() == 1
bsz = block_tables.size(0) if bsz is None else bsz
assert bsz == 1, "Support bsz 1 for now" # TODO support bsz > 1
seqs_to_recycle = []
for i in range(n):
seqs_to_recycle += self.allocate_tokens_from_block_tables(block_tables, context_lens - n + i + 1, bsz)
return seqs_to_recycle
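In other words, reserving `n` new KV slots is performed as `n` single-token allocations with the context length advanced one position at a time; a schematic sketch (the `allocate_one` callable is hypothetical, standing in for the per-token allocator above):

```python
# Schematic only: for a sequence whose length is now 16 and n = 3,
# the single-token allocator is invoked at lengths 14, 15 and 16.
def allocate_n(allocate_one, current_len: int, n: int) -> None:
    for i in range(n):
        allocate_one(current_len - n + i + 1)
```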
def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int:
"""Allocate space asked on a single block in the block table, specified by the provided position id,
and updates the provided block table with the allocated block.
@ -420,9 +440,7 @@ class KVCacheManager:
Returns:
The remaining space required to be allocated (in other blocks).
"""
assert block.available_space > 0, f"Found no available space left in the chosen block {block}."
space_to_allocate = min(block.available_space, space_asked)
block.allocate(space_to_allocate)
return space_asked - space_to_allocate

View File

@ -0,0 +1,475 @@
# This is modified from huggingface transformers
# https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/models/llama/modeling_llama.py
import warnings
from types import MethodType
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaForCausalLM,
LlamaLinearScalingRotaryEmbedding,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)
from colossalai.inference.spec import GlideInput
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_single_rotary_pos_emb(q, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
return q_embed
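For intuition, `rotate_half` maps the two halves of the last dimension `(x1, x2)` to `(-x2, x1)`, which is the standard RoPE rotation; a quick check (illustrative only, using the function defined above):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
# halves are x1 = [1, 2] and x2 = [3, 4], so the rotation yields [-3, -4, 1, 2]
assert torch.equal(rotate_half(x), torch.tensor([-3.0, -4.0, 1.0, 2.0]))
```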
def glide_llama_causal_lm_forward(
self: LlamaForCausalLM,
input_ids: torch.LongTensor = None,
glide_input: Optional[GlideInput] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
glide_input=glide_input,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if not return_dict:
output = (logits,) + outputs[1:]
return output
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def glide_llama_model_forward(
self: LlamaModel,
input_ids: torch.LongTensor = None,
glide_input: GlideInput = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
# embed positions
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
# GlideLlamaDecoderLayer
layer_outputs = decoder_layer(
hidden_states,
glide_input=glide_input,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class GlideLlamaConfig(LlamaConfig):
"""Configuration class with specific arguments used by GLIDE llama model as a drafter"""
def __init__(
self,
large_hidden_size=4096,
large_num_attention_heads=32,
**kwargs,
):
super().__init__(**kwargs)
self.large_hidden_size = large_hidden_size
self.large_num_attention_heads = large_num_attention_heads
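A hedged sketch of a drafter config: the `large_*` fields must match the verifier (here sized for a Llama-2-7B-style model with hidden size 4096 and 32 heads), while the drafter-side sizes below are illustrative assumptions.

```python
drafter_config = GlideLlamaConfig(
    hidden_size=1024,              # drafter sizes: illustrative only
    intermediate_size=2816,
    num_attention_heads=8,
    num_hidden_layers=1,
    large_hidden_size=4096,        # verifier (main model) hidden size
    large_num_attention_heads=32,  # verifier attention heads
)
```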
class LlamaCrossAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: GlideLlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
# large model (verifier) configs
self.large_hidden_size = config.large_hidden_size
self.large_num_heads = config.large_num_attention_heads
self.large_head_dim = self.large_hidden_size // self.large_num_heads
self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False)
self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.large_head_dim,
max_position_embeddings=self.max_position_embeddings,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.large_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.large_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
glide_input: GlideInput = None, # Used for glimpsing main model's KV caches
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Optional[torch.Tensor]:
bsz, q_len, _ = hidden_states.size()
block_tables = glide_input.block_tables
large_k_cache = glide_input.large_k_cache
large_v_cache = glide_input.large_v_cache
sequence_lengths = glide_input.sequence_lengths
cache_block_size = large_k_cache.size(-2)
query_states = self.q_proj(hidden_states)
kv_seq_len = sequence_lengths.max().item()
query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)
# for RoPE
cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len + 32)
query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
query_states = query_states.transpose(1, 2)
query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=large_k_cache,
v_cache=large_v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=cache_block_size,
max_seq_len_in_batch=kv_seq_len,
) # attn_output: [bsz * q_len, num_heads * head_dim]
attn_output = attn_output.reshape(bsz, q_len, self.large_hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output
# A class used to replace LlamaDecoderLayer in a Llama model so that it can serve as the drafter in speculative decoding.
# Refer to GLIDE with a CAPE https://arxiv.org/pdf/2402.02082.pdf
class GlideLlamaDecoderLayer(nn.Module):
def __init__(self, config: GlideLlamaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
self.cross_attn = LlamaCrossAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@staticmethod
def from_native_module(module: LlamaDecoderLayer, *args, **kwargs) -> "GlideLlamaDecoderLayer":
"""Build a GlideLlamaDecoderLayer from a native LlamaDecoderLayer"""
config: LlamaConfig = module.mlp.config # XXX
layer_idx = module.self_attn.layer_idx
glide_config = GlideLlamaConfig(**config.to_dict())
glide_decoder_layer = GlideLlamaDecoderLayer(glide_config, layer_idx=layer_idx)
return glide_decoder_layer
def forward(
self,
hidden_states: torch.Tensor,
glide_input: GlideInput = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs,
)
hidden_states = residual + hidden_states
curr_q_len = hidden_states.size(1)
# Cross attention
if glide_input is None or not glide_input.glimpse_ready:
warnings.warn(
"Data used for glimpsing the past KV caches of the main model (verifier) is not complete. "
"Fall back to normal decoder layer modeling (drafter). "
"This might lead to incorrect results when using the Glide Models for speculative decoding."
)
elif curr_q_len == 1:
# Notice that we skip prefill stage
# always use the output of the main model as the inputs for the next round of speculation
residual = hidden_states
hidden_states = self.cross_attn(
hidden_states=hidden_states,
glide_input=glide_input,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
use_cache=True,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if use_cache:
outputs += (present_key_value,)
return outputs
class GlideLlamaForCausalLM(LlamaForCausalLM):
def __init__(self, config: GlideLlamaConfig):
super().__init__(config)
self.config = config
bound_method = MethodType(glide_llama_causal_lm_forward, self)
setattr(self, "forward", bound_method)
bound_method = MethodType(glide_llama_model_forward, self.model)
model = getattr(self, "model")
setattr(model, "forward", bound_method)
replaced_layers = nn.ModuleList(
[GlideLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
setattr(model, "layers", replaced_layers)

View File

@ -18,6 +18,7 @@ from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
flash_decoding_attention,
get_xine_cache,
@ -84,9 +85,9 @@ def llama_model_forward(
"""This function will replace the forward function of LlamaModel. """This function will replace the forward function of LlamaModel.
Args: Args:
batch (BatchInfo): It stores the necessary input information for this inference. batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None.
k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache. k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache. v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
""" """
block_tables = inputmetadata.block_tables block_tables = inputmetadata.block_tables
@ -100,8 +101,31 @@ def llama_model_forward(
if batch_size >= 32 and kv_seq_len > 512:
use_cuda_kernel = False
# NOTE (yuanheng-zhao): for now, only triton kernels support the verification process
# during speculative-decoding (`q_len > 1`)
# We will explicitly disable `use_cuda_kernel` here when speculative-decoding is enabled
if inputmetadata.use_spec_dec and use_cuda_kernel:
use_cuda_kernel = False
logger.warning("CUDA kernel is disabled for speculative-decoding.")
hidden_states = self.embed_tokens(input_tokens_ids)
cu_seqlens = None
# NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now
if inputmetadata.use_spec_dec:
# For speculative-decoding Prefill and Verifying Stage
if inputmetadata.is_prompts:
# output tensor shape is the same as normal Prefill Stage
rotary_indexes = [torch.arange(0, length) for length in sequence_lengths]
else:
# the number of tokens to be verified in parallel plus the correct token in the last step
n_tokens = inputmetadata.num_tokens_to_verify + 1
assert n_tokens == hidden_states.size(0)
rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
rotary_indexes = torch.cat(rotary_indexes, dim=-1)
cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
elif use_cuda_kernel:
if inputmetadata != torch.float32 and use_flash_attn2:
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
@ -113,14 +137,13 @@ def llama_model_forward(
self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts
)
cos_sin = (cos, sin)
else:
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
norm_output = torch.empty_like(hidden_states)
tokens_to_verify = inputmetadata.num_tokens_to_verify if inputmetadata.use_spec_dec else None
residual = None
for layer_id, decoder_layer in enumerate(self.layers):
@ -131,6 +154,8 @@ def llama_model_forward(
k_cache=k_caches[layer_id],
v_cache=v_caches[layer_id],
is_prompts=inputmetadata.is_prompts,
is_verifier=inputmetadata.use_spec_dec,
tokens_to_verify=tokens_to_verify,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=inputmetadata.fd_inter_tensor,
@ -144,9 +169,9 @@ def llama_model_forward(
)
if inputmetadata.is_prompts:
seq_len_cumsum = sequence_lengths.cumsum(dim=0)
hidden_states = hidden_states[seq_len_cumsum - 1].contiguous()
residual = residual[seq_len_cumsum - 1].contiguous()
norm_output = torch.empty_like(hidden_states)
hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
@ -164,6 +189,8 @@ def llama_decoder_layer_forward(
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
is_verifier: bool = False,
tokens_to_verify: int = None,
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
norm_output: torch.Tensor = None,
@ -202,10 +229,12 @@ def llama_decoder_layer_forward(
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
is_prompts=is_prompts,
is_verifier=is_verifier,
tokens_to_verify=tokens_to_verify,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
sm_scale=sm_scale,
@ -312,6 +341,8 @@ class NopadLlamaAttention(LlamaAttention):
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
is_verifier: bool = False,
tokens_to_verify: int = None,
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
@ -355,7 +386,7 @@ class NopadLlamaAttention(LlamaAttention):
block_size = k_cache.size(-2)
if is_prompts:
if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
# flash attn 2 currently only supports FP16/BF16.
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
inference_ops.context_kv_cache_memcpy(
@ -391,6 +422,8 @@ class NopadLlamaAttention(LlamaAttention):
sm_scale=sm_scale, sm_scale=sm_scale,
) )
else: else:
q_len = tokens_to_verify + 1 if is_verifier else 1
if use_cuda_kernel: if use_cuda_kernel:
inference_ops.rotary_embedding_and_cache_copy( inference_ops.rotary_embedding_and_cache_copy(
query_states, query_states,
@ -405,17 +438,26 @@ class NopadLlamaAttention(LlamaAttention):
high_precision, high_precision,
) )
else: else:
decoding_fused_rotary_embedding( if is_verifier:
query_states, rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
key_states, copy_k_to_blocked_cache(
value_states, key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
cos_sin[0], )
cos_sin[1], copy_k_to_blocked_cache(
k_cache, value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
v_cache, )
block_tables, else:
sequence_lengths, decoding_fused_rotary_embedding(
) query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
attn_output = flash_decoding_attention( attn_output = flash_decoding_attention(
q=query_states, q=query_states,
k_cache=k_cache, k_cache=k_cache,
@ -428,8 +470,10 @@ class NopadLlamaAttention(LlamaAttention):
mid_output=fd_inter_tensor.mid_output, mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse, mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale, sm_scale=sm_scale,
q_len=q_len,
) )
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj_weight) attn_output = torch.mm(attn_output, self.o_proj_weight)
return attn_output return attn_output

View File

@ -1,7 +1,9 @@
from .glide_llama import GlideLlamaModelPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy from .nopadding_llama import NoPaddingLlamaModelInferPolicy
model_policy_map = { model_policy_map = {
"nopadding_llama": NoPaddingLlamaModelInferPolicy, "nopadding_llama": NoPaddingLlamaModelInferPolicy,
"glide_llama": GlideLlamaModelPolicy,
} }
__all__ = ["NoPaddingLlamaModelInferPolicy", "model_polic_map"] __all__ = ["NoPaddingLlamaModelInferPolicy", "GlideLlamaModelPolicy", "model_polic_map"]

View File

@ -0,0 +1,45 @@
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
from colossalai.inference.modeling.models.glide_llama import (
GlideLlamaDecoderLayer,
glide_llama_causal_lm_forward,
glide_llama_model_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
class GlideLlamaModelPolicy(LlamaForCausalLMPolicy):
def module_policy(self):
policy = super().module_policy()
num_layers = self.model.config.num_hidden_layers
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix=f"layers[{i}]",
target_module=GlideLlamaDecoderLayer,
)
for i in range(num_layers)
],
policy=policy,
target_key=LlamaModel,
)
self.append_or_create_method_replacement(
description={"forward": glide_llama_model_forward},
policy=policy,
target_key=LlamaModel,
)
self.append_or_create_method_replacement(
description={"forward": glide_llama_causal_lm_forward},
policy=policy,
target_key=LlamaForCausalLM,
)
return policy
def postprocess(self):
for layer in self.model.model.layers:
init_to_get_rotary(layer.cross_attn)
return self.model

View File

@ -0,0 +1,96 @@
# Speculative Decoding
Colossal-Inference supports speculative decoding using the inference engine, with optimized kernels and cache management for the main model.
Both a drafter model (a small model) and a main model (a large model) are used during the speculative decoding process. The drafter model generates a few tokens sequentially, and the main model then validates those candidate tokens in parallel, accepting the ones that match its own predictions. Decoding is sped up because speculating multiple tokens with the drafter model has lower latency than generating them one by one with the main model.
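Conceptually, one greedy draft-and-verify step works as in the sketch below. This is a minimal, self-contained toy illustration of the acceptance rule only, not the engine's actual `Drafter`/`InferenceEngine` implementation; the helpers `draft_next_token` and `verify_tokens` are hypothetical placeholders standing in for the drafter and main model.
```python
def speculative_step(draft_next_token, verify_tokens, prefix, n_spec_tokens=5):
    """One greedy draft-then-verify step (toy sketch, not the engine API).

    draft_next_token(prefix) -> token id proposed by the small drafter.
    verify_tokens(prefix, candidates) -> list of n_spec_tokens + 1 greedy tokens
        predicted by the main model at each candidate position (the extra one
        is the "bonus" token following all candidates).
    """
    # 1) The drafter proposes n_spec_tokens tokens sequentially (cheap, small model).
    candidates, draft_prefix = [], list(prefix)
    for _ in range(n_spec_tokens):
        token = draft_next_token(draft_prefix)
        candidates.append(token)
        draft_prefix.append(token)

    # 2) The main model scores all candidate positions in a single parallel pass.
    main_preds = verify_tokens(prefix, candidates)

    # 3) Accept the longest matching prefix of candidates, then take one token
    #    from the main model itself (a correction on mismatch, a bonus otherwise).
    accepted = []
    for cand, pred in zip(candidates, main_preds):
        if cand != pred:
            break
        accepted.append(cand)
    accepted.append(main_preds[len(accepted)])
    return prefix + accepted


# Toy demo: the drafter always proposes token 7; the main model agrees twice.
print(speculative_step(
    draft_next_token=lambda p: 7,
    verify_tokens=lambda p, cands: [7, 7, 9, 9, 9, 9],
    prefix=[1, 2, 3],
))  # [1, 2, 3, 7, 7, 9]  -> two accepted draft tokens + one corrected token
```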
Moreover, Colossal-Inference also supports GLIDE, a modified draft-model architecture that reuses the key and value caches of the main model, which improves the acceptance rate and increases the speed-up ratio. Details can be found in the paper "GLIDE with a CAPE: A Low-Hassle Method to Accelerate Speculative Decoding" on [arXiv](https://arxiv.org/pdf/2402.02082.pdf).
Right now, Colossal-Inference offers a GLIDE model compatible with Vicuna-7B. You can find the fine-tuned GLIDE drafter model `cxdu/glide47m-vicuna7b` on the HuggingFace Hub: https://huggingface.co/cxdu/glide47m-vicuna7b.
## Usage
For the main model, you might want to use the model card `lmsys/vicuna-7b-v1.5` on the [HuggingFace Hub](https://huggingface.co/lmsys/vicuna-7b-v1.5).
For a regular drafter model, you might want to use the model card `JackFram/llama-68m` on the [HuggingFace Hub](https://huggingface.co/JackFram/llama-68m).
For the GLIDE drafter model, you could use the model card `cxdu/glide47m-vicuna7b` on the [HuggingFace Hub](https://huggingface.co/cxdu/glide47m-vicuna7b).
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine, GenerationConfig
from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM, GlideLlamaConfig
# launch colossalai, setup distributed environment
colossalai.launch_from_torch(config={})
# main model
model_path_or_name = "REPLACE_TO_VICUNA_7B_PATH_OR_MODEL_CARD"
model = AutoModelForCausalLM.from_pretrained(model_path_or_name)
# use the same tokenizer for both the main model and the drafter model
tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
tokenizer.pad_token = tokenizer.eos_token
# drafter model
drafter_model_path_or_name = "REPLACE_TO_LLAMA_68M_PATH_OR_MODEL_CARD"
drafter_model = AutoModelForCausalLM.from_pretrained(drafter_model_path_or_name)
# Initialize the inference engine
inference_config = InferenceConfig(
dtype="fp16",
max_batch_size=1,
max_input_len=256,
max_output_len=256,
prefill_ratio=1.2,
block_size=16,
max_n_spec_tokens=5,
prompt_template="vicuna",
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
# turn on speculative decoding with the drafter model
engine.enable_spec_dec(drafter_model)
prompt = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions. "
generation_config = GenerationConfig(
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_length=128,
num_beams=1,
do_sample=False,
)
out = engine.generate(prompts=[prompt], generation_config=generation_config)
print(out)
# use GLIDE Llama model as drafter model
drafter_model_path_or_name = "cxdu/glide47m-vicuna7b"
glide_config = GlideLlamaConfig(
intermediate_size=8192,
large_hidden_size=4096,
large_num_attention_heads=32,
num_hidden_layers=1,
)
drafter_model = GlideLlamaForCausalLM.from_pretrained(drafter_model_path_or_name, config=glide_config)
# turn on speculative decoding with the GLIDE model
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
out = engine.generate(prompts=[prompt], generation_config=generation_config)
print(out)
```
You can run the above code with
```bash
colossalai run --nproc_per_node 1 script_name.py
```
## Benchmark
With batch size 1, tested on the GSM8K and MT-Bench datasets on an NVIDIA H800 80GB GPU:
| Method | Tokens/Sec |
| :--------------------------- | :--------- |
| Non-Spec-Dec | ~90 |
| Spec-Dec | ~115 |
| Spec-Dec with GLIDE Model | ~135 |

View File

@ -0,0 +1,4 @@
from .drafter import Drafter
from .struct import DrafterOutput, GlideInput
__all__ = ["Drafter", "DrafterOutput", "GlideInput"]

View File

@ -0,0 +1,121 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer
from colossalai.utils import get_current_device
from .struct import DrafterOutput, GlideInput
class Drafter:
"""Container for the Drafter Model (Assistant Model) used in Speculative Decoding.
Args:
model (nn.Module): The drafter model.
tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
device (torch.device): The device for the drafter model.
dtype (torch.dtype): The data type for the drafter model. Defaults to torch.float16.
"""
def __init__(
self,
model: nn.Module,
tokenizer: PreTrainedTokenizer,
device: torch.device = None,
dtype: torch.dtype = torch.float16,
):
self._tokenizer = tokenizer
self._device = device or get_current_device()
self._dtype = dtype
self._drafter_model = model.to(device=self._device, dtype=self._dtype)
self._drafter_model.eval()
def get_model(self) -> nn.Module:
return self._drafter_model
@staticmethod
def trim_kv_cache(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int
) -> Tuple[Tuple[torch.FloatTensor]]:
"""Trim the last `invalid_token_num` kv caches.
past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape
num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
invalid_token_num (int): The number of invalid tokens to trim.
"""
if past_key_values is None or invalid_token_num < 1:
return past_key_values
trimmed_past_key_values = []
for layer_idx in range(len(past_key_values)):
past_key_value = past_key_values[layer_idx]
trimmed_past_key_values.append(
(
past_key_value[0][:, :, :-invalid_token_num, :],
past_key_value[1][:, :, :-invalid_token_num, :],
)
)
past_key_values = tuple(trimmed_past_key_values)
return past_key_values
@torch.inference_mode()
def speculate(
self,
input_ids: torch.Tensor,
n_spec_tokens: int,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
glide_input: Optional[GlideInput] = None,
) -> DrafterOutput:
"""Generate n_spec_tokens tokens using the drafter model.
Args:
input_ids (torch.Tensor): Input token ids.
n_spec_tokens (int): Number of tokens to speculate.
past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
glide_input (Optional[GlideInput]): The packed input for glimpsing kv caches of the main model,
when using the glide model as a drafter.
"""
assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate"
# For compatibility with transformers of versions before 4.38.0
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
logits = []
token_ids = []
kwargs = {"return_dict": True, "use_cache": True}
if glide_input:
# required only when using glide model
kwargs["glide_input"] = glide_input
for _ in range(n_spec_tokens):
# update past key values
kwargs["past_key_values"] = past_key_values
outputs = self._drafter_model(input_ids, **kwargs)
next_token_logits = outputs.logits[:, -1, :]
# NOTE Only use greedy search for speculating.
# As the drafter model usually has only a few layers with few parameters,
# introducing sampling will make the speculation unstable and lead to worse performance.
next_token_ids = torch.argmax(next_token_logits, dim=-1)
logits.append(next_token_logits)
token_ids.append(next_token_ids)
if next_token_ids.item() == self._tokenizer.eos_token_id:
# TODO(yuanheng-zhao) support bsz > 1
break
input_ids = next_token_ids[:, None]
past_key_values = outputs.past_key_values
speculated_length = len(token_ids) # For now, only support bsz 1
logits = torch.concat(logits, dim=0)
token_ids = torch.concat(token_ids, dim=-1)
out = DrafterOutput(
speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
)
return out
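As a reference, a minimal usage sketch of the `Drafter` wrapper defined above is given below. It assumes a small Llama-style checkpoint such as `JackFram/llama-68m` and a CUDA device are available; the rejection count used for trimming is made up purely for illustration.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.spec import Drafter
from colossalai.utils import get_current_device

# Assumed small drafter checkpoint; any compact causal LM would do.
drafter_name = "JackFram/llama-68m"
tokenizer = AutoTokenizer.from_pretrained(drafter_name)
drafter_model = AutoModelForCausalLM.from_pretrained(drafter_name)

drafter = Drafter(drafter_model, tokenizer, device=get_current_device(), dtype=torch.float16)

input_ids = tokenizer("Speculative decoding works by", return_tensors="pt").input_ids.to(get_current_device())
out = drafter.speculate(input_ids, n_spec_tokens=5)
print(out.speculated_length, out.next_tokens)

# If the main model later rejects, say, the last 2 drafted tokens, the drafter's
# kv cache is trimmed before the next speculation round.
trimmed_kv = Drafter.trim_kv_cache(out.past_key_values, invalid_token_num=2)
```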

View File

@ -0,0 +1,55 @@
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
@dataclass
class DrafterOutput:
"""
Dataclass for drafter model outputs.
Args:
speculated_length (int): Speculated length of the output sequence.
It is always less than or equal to the requested number of speculated tokens (`n_spec_tokens`).
logits (torch.FloatTensor): Logits of the output sequence
next_tokens (torch.Tensor): Next token ids
past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]]): Past key values of the output sequence
"""
speculated_length: int = None
logits: torch.FloatTensor = None
next_tokens: torch.Tensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
def __post_init__(self):
assert self.speculated_length is not None and self.speculated_length >= 0
if self.past_key_values is not None:
assert isinstance(self.past_key_values, tuple), "Past key values should be a tuple"
assert all([isinstance(past_key_value, tuple) for past_key_value in self.past_key_values])
@dataclass
class GlideInput:
"""Dataclass for Glide Models (e.g. `colossalai/inference/modeling/models/glide_llama.py`).
Used to pack the data needed for glimpsing the KV caches of the main model.
Args:
block_tables (torch.Tensor): [num_seqs, max_blocks_per_seq] The block table of KV Caches.
large_k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_size]
Blocked key cache of the main model
large_v_cache (torch.Tensor): Blocked value cache of the main model. It has the same shape as k cache.
sequence_lengths (torch.Tensor): [num_seqs] Sequence lengths of the current batch.
"""
block_tables: torch.Tensor = None
large_k_cache: torch.Tensor = None
large_v_cache: torch.Tensor = None
sequence_lengths: torch.Tensor = None
@property
def glimpse_ready(self):
return all(
attr is not None
for attr in [self.block_tables, self.large_k_cache, self.large_v_cache, self.sequence_lengths]
)
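For illustration, the engine might populate a `GlideInput` from the main model's paged KV cache roughly as follows before calling the GLIDE drafter. The shapes and the CUDA device below are placeholder assumptions, not values taken from this PR.
```python
import torch

from colossalai.inference.spec import GlideInput

# Placeholder cache geometry: 8 blocks, 32 kv heads, block size 16, head size 128.
num_blocks, num_kv_heads, block_size, head_size = 8, 32, 16, 128
device = "cuda"

glide_input = GlideInput(
    block_tables=torch.zeros((1, 4), dtype=torch.int32, device=device),  # [num_seqs, max_blocks_per_seq]
    large_k_cache=torch.zeros((num_blocks, num_kv_heads, block_size, head_size), device=device),
    large_v_cache=torch.zeros((num_blocks, num_kv_heads, block_size, head_size), device=device),
    sequence_lengths=torch.tensor([10], dtype=torch.int32, device=device),  # [num_seqs]
)
assert glide_input.glimpse_ready

# The packed input is then forwarded to the drafter so its cross-attention can
# glimpse the main model's kv cache, e.g.:
# drafter.speculate(input_ids, n_spec_tokens, past_key_values, glide_input=glide_input)
```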

View File

@ -11,7 +11,7 @@ if HAS_TRITON:
from .context_attn_unpad import context_attention_unpadded from .context_attn_unpad import context_attention_unpadded
from .flash_decoding import flash_decoding_attention from .flash_decoding import flash_decoding_attention
from .fused_rotary_embedding import fused_rotary_embedding from .fused_rotary_embedding import fused_rotary_embedding
from .kvcache_copy import copy_kv_to_blocked_cache from .kvcache_copy import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
from .no_pad_rotary_embedding import decoding_fused_rotary_embedding, rotary_embedding from .no_pad_rotary_embedding import decoding_fused_rotary_embedding, rotary_embedding
from .rms_layernorm import rms_layernorm from .rms_layernorm import rms_layernorm
from .rotary_cache_copy import get_xine_cache from .rotary_cache_copy import get_xine_cache
@ -20,6 +20,7 @@ if HAS_TRITON:
__all__ = [ __all__ = [
"context_attention_unpadded", "context_attention_unpadded",
"flash_decoding_attention", "flash_decoding_attention",
"copy_k_to_blocked_cache",
"copy_kv_to_blocked_cache", "copy_kv_to_blocked_cache",
"softmax", "softmax",
"rms_layernorm", "rms_layernorm",

View File

@ -9,13 +9,14 @@ import triton.language as tl
# Triton 2.1.0 # Triton 2.1.0
@triton.jit @triton.jit
def _flash_decoding_fwd_kernel( def _flash_decoding_fwd_kernel(
Q, # [batch_size, head_num, q_len(1), head_dim] Q, # [batch_size * q_len, head_num, head_dim]
KCache, # [num_blocks, num_kv_heads, block_size, head_dim] KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
VCache, # [num_blocks, num_kv_heads, block_size, head_dim] VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
block_tables, # [batch_size, max_blocks_per_sequence] block_tables, # [batch_size, max_blocks_per_sequence]
mid_o, # [batch_size, head_num, kv_split_num, head_dim] mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim]
mid_o_lse, # [batch_size, head_num, kv_split_num] mid_o_lse, # [batch_size * q_len, head_num, kv_split_num]
kv_seq_len, # [batch_size] kv_seq_len, # [batch_size]
q_len,
batch_size, batch_size,
stride_qt, stride_qt,
stride_qh, stride_qh,
@ -39,44 +40,39 @@ def _flash_decoding_fwd_kernel(
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
HEAD_DIM: tl.constexpr, HEAD_DIM: tl.constexpr,
): ):
cur_seq_idx = tl.program_id(0) cur_token_idx = tl.program_id(0)
cur_seq_idx = cur_token_idx // q_len
if cur_seq_idx >= batch_size: if cur_seq_idx >= batch_size:
return return
cur_token_off = (cur_token_idx % q_len) - q_len + 1
cur_head_idx = tl.program_id(1) cur_head_idx = tl.program_id(1)
block_start_kv = tl.program_id(2) # for splitting k/v block_start_kv = tl.program_id(2) # for splitting k/v
cur_kv_head_idx = cur_head_idx // KV_GROUPS
offsets_dmodel = tl.arange(0, HEAD_DIM)
# NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same # NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same
# TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE) # TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE)
# and then support calculating multiple kv cache blocks on an instance # and then support calculating multiple kv cache blocks on an instance
tl.static_assert(BLOCK_KV == BLOCK_SIZE) tl.static_assert(BLOCK_KV == BLOCK_SIZE)
# get the current (kv) sequence length
# get the current (kv) sequence length from provided context lengths tensor # cur_token_off is used as a "mask" here for spec-dec during verification process
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
q = tl.load(Q + offsets_q)
# block table for the current sequence
block_table_ptr = block_tables + cur_seq_idx * stride_bts
# actually current block table current block start idx
# cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
cur_bt_start_idx = block_start_kv
cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
if block_start_kv * BLOCK_KV >= cur_kv_seq_len: if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
return return
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
q = tl.load(Q + offsets_q)
# block table for the current sequence
block_table_ptr = block_tables + cur_seq_idx * stride_bts
# cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
# cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb)
cur_occupied_size = tl.where( cur_occupied_size = tl.where(
(block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE (block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE
) )
tl.device_assert(cur_occupied_size >= 0) tl.device_assert(cur_occupied_size >= 0)
cur_kv_head_idx = cur_head_idx // KV_GROUPS
offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
K_block_ptr = tl.make_block_ptr( K_block_ptr = tl.make_block_ptr(
base=KCache + offset_kvcache, base=KCache + offset_kvcache,
shape=(cur_occupied_size, HEAD_DIM), shape=(cur_occupied_size, HEAD_DIM),
@ -115,14 +111,14 @@ def _flash_decoding_fwd_kernel(
acc = acc / l acc = acc / l
offsets_mid_o = ( offsets_mid_o = (
cur_seq_idx * stride_mid_ot cur_token_idx * stride_mid_ot
+ cur_head_idx * stride_mid_oh + cur_head_idx * stride_mid_oh
+ block_start_kv * stride_mid_ob + block_start_kv * stride_mid_ob
+ offsets_dmodel * stride_mid_od + offsets_dmodel * stride_mid_od
) )
tl.store(mid_o + offsets_mid_o, acc) tl.store(mid_o + offsets_mid_o, acc)
offsets_mid_o_lse = ( offsets_mid_o_lse = (
cur_seq_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
) )
# logsumexp L^(j) = m^(j) + log(l^(j)) # logsumexp L^(j) = m^(j) + log(l^(j))
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l)) tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
@ -135,6 +131,7 @@ def _flash_decoding_fwd_reduce_kernel(
mid_o_lse, # [batch_size, head_num, kv_split_num] mid_o_lse, # [batch_size, head_num, kv_split_num]
O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim] O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
kv_seq_len, kv_seq_len,
q_len,
batch_size, batch_size,
stride_mid_ot, stride_mid_ot,
stride_mid_oh, stride_mid_oh,
@ -149,12 +146,15 @@ def _flash_decoding_fwd_reduce_kernel(
BLOCK_KV: tl.constexpr, BLOCK_KV: tl.constexpr,
HEAD_DIM: tl.constexpr, HEAD_DIM: tl.constexpr,
): ):
cur_seq_idx = tl.program_id(0) cur_token_idx = tl.program_id(0)
cur_seq_idx = cur_token_idx // q_len
if cur_seq_idx >= batch_size: if cur_seq_idx >= batch_size:
return return
cur_head_idx = tl.program_id(1) cur_head_idx = tl.program_id(1)
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) # cur_token_off is used as a "mask" here for spec-dec during verification process
cur_token_off = (cur_token_idx % q_len) - q_len + 1
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
offsets_dmodel = tl.arange(0, HEAD_DIM) offsets_dmodel = tl.arange(0, HEAD_DIM)
# NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have # NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have
@ -164,8 +164,8 @@ def _flash_decoding_fwd_reduce_kernel(
l = 0.0 # sum exp l = 0.0 # sum exp
acc = tl.zeros([HEAD_DIM], dtype=tl.float32) acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
offsets_mid_o = cur_seq_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
offset_mid_lse = cur_seq_idx * stride_o_lset + cur_head_idx * stride_o_lseh offset_mid_lse = cur_token_idx * stride_o_lset + cur_head_idx * stride_o_lseh
for block_i in range(0, kv_split_num, 1): for block_i in range(0, kv_split_num, 1):
mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob) mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob)
lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb) lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb)
@ -179,7 +179,7 @@ def _flash_decoding_fwd_reduce_kernel(
m_i = m_ij m_i = m_ij
acc = acc / l acc = acc / l
offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
tl.store(O + offsets_O, acc.to(O.type.element_ty)) tl.store(O + offsets_O, acc.to(O.type.element_ty))
return return
@ -199,12 +199,14 @@ def flash_decoding_attention(
mid_output_lse: torch.Tensor = None, mid_output_lse: torch.Tensor = None,
sm_scale: int = None, sm_scale: int = None,
kv_group_num: int = 1, kv_group_num: int = 1,
q_len: int = 1,
): ):
""" """
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage. Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
Args: Args:
q (torch.Tensor): [bsz, num_heads, head_dim] q (torch.Tensor): [bsz * q_len, num_heads, head_dim]
q_len > 1 only for verification process in speculative-decoding.
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
kv_seq_len (torch.Tensor): [batch_size] kv_seq_len (torch.Tensor): [batch_size]
@ -212,19 +214,25 @@ def flash_decoding_attention(
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence] block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
max_seq_len_in_batch (int): Maximum sequence length in the batch. max_seq_len_in_batch (int): Maximum sequence length in the batch.
output (torch.Tensor): [bsz, num_heads * head_dim] output (torch.Tensor): [bsz, num_heads * head_dim]
mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim] mid_output (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num, head_dim]
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`. Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num] q_len > 1 only for verification process in speculative-decoding.
mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num]
Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`. Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
q_len > 1 only for verification process in speculative-decoding.
block_size (int): Size of each block in the blocked key/value cache. block_size (int): Size of each block in the blocked key/value cache.
num_kv_group (int, optional): Number of key/value groups. Defaults to 1. num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
q_len (int): Query length. Used for speculative decoding when `q_len` > 1 (i.e. verifying the last n tokens).
Defaults to 1.
Returns: Returns:
Output tensor with shape [bsz, num_heads * head_dim] Output tensor with shape [bsz * q_len, num_heads * head_dim]
""" """
q = q.squeeze() if q.dim() == 4 else q q = q.squeeze() if q.dim() == 4 else q
assert q.dim() == 3, f"Incompatible q dim: {q.dim()}" assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
bsz, num_heads, head_dim = q.shape n_tokens, num_heads, head_dim = q.shape
assert n_tokens % q_len == 0, "Invalid q_len"
bsz = n_tokens // q_len
assert head_dim in {32, 64, 128, 256} assert head_dim in {32, 64, 128, 256}
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, ( assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
@ -247,22 +255,31 @@ def flash_decoding_attention(
max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
# For compatibility (TODO revise modeling in future) # For compatibility (TODO revise modeling in future)
kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
mid_output = (
torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device) if mid_output is None:
if mid_output is None mid_output = torch.empty(
else mid_output (bsz * q_len, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
) )
mid_output_lse = ( if mid_output_lse is None:
torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device) mid_output_lse = torch.empty((bsz * q_len, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
if mid_output_lse is None if output is None:
else mid_output_lse # A hack to prevent `view` operation in modeling
) output = torch.empty((bsz * q_len, num_heads * head_dim), dtype=q.dtype, device=q.device)
assert (
mid_output.size(2) == mid_output_lse.size(2) >= kv_max_split_num
), "Incompatible kv split number of intermediate output tensors"
assert (
mid_output.size(0) == mid_output_lse.size(0) >= output.size(0) == n_tokens
), f"Incompatible first dimension of output tensors"
# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred) # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV)) grid = (
output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output triton.next_power_of_2(bsz * q_len),
num_heads,
triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
)
_flash_decoding_fwd_kernel[grid]( _flash_decoding_fwd_kernel[grid](
q, q,
k_cache, k_cache,
@ -271,6 +288,7 @@ def flash_decoding_attention(
mid_output, mid_output,
mid_output_lse, mid_output_lse,
kv_seq_len, kv_seq_len,
q_len,
bsz, bsz,
q.stride(0), q.stride(0),
q.stride(1), q.stride(1),
@ -295,13 +313,13 @@ def flash_decoding_attention(
HEAD_DIM=head_dim, HEAD_DIM=head_dim,
) )
grid = (triton.next_power_of_2(bsz), num_heads) grid = (triton.next_power_of_2(bsz * q_len), num_heads)
_flash_decoding_fwd_reduce_kernel[grid]( _flash_decoding_fwd_reduce_kernel[grid](
mid_output, mid_output,
mid_output_lse, mid_output_lse,
output, output,
kv_seq_len, kv_seq_len,
q_len,
bsz, bsz,
mid_output.stride(0), mid_output.stride(0),
mid_output.stride(1), mid_output.stride(1),

View File

@ -3,6 +3,50 @@ import triton
import triton.language as tl import triton.language as tl
# Triton 2.1.0
@triton.jit
def _copy_to_kcache_seqlen_n_kernel(
KV, # K or V
KVCache, # KCache or VCache
BLOCK_TABLES,
context_lengths,
stride_kt,
stride_kh,
stride_kd,
stride_cacheb,
stride_cacheh,
stride_cachebs,
stride_cached,
stride_bts,
stride_btb,
block_size,
n,
HEAD_DIM: tl.constexpr,
):
cur_token_idx = tl.program_id(0)
cur_seq_idx = cur_token_idx // n
# offset of this token within the last n tokens of its sequence, in range [-n, -1]
cur_token_shift = cur_token_idx - (n * (cur_seq_idx + 1))
cur_kv_head_idx = tl.program_id(1)
past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) + cur_token_shift
last_bt_block_idx = past_kv_seq_len // block_size
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
offset_last_block = past_kv_seq_len % block_size
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kv = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
kv = tl.load(KV + offsets_kv)
offsets_kvcache = (
block_id * stride_cacheb
+ cur_kv_head_idx * stride_cacheh
+ offset_last_block * stride_cachebs
+ offsets_dmodel * stride_cached
)
tl.store(KVCache + offsets_kvcache, kv)
return
# Triton 2.1.0 # Triton 2.1.0
@triton.jit @triton.jit
def _copy_to_kvcache_seqlen1_kernel( def _copy_to_kvcache_seqlen1_kernel(
@ -40,10 +84,11 @@ def _copy_to_kvcache_seqlen1_kernel(
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb) block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
offsets_in_last_block = past_kv_seq_len % block_size offsets_in_last_block = past_kv_seq_len % block_size
offsets_dmodel = tl.arange(0, HEAD_DIM) offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel * stride_vd
k = tl.load(K + offsets_kv) k = tl.load(K + offsets_k)
v = tl.load(V + offsets_kv) v = tl.load(V + offsets_v)
offsets_kcache = ( offsets_kcache = (
block_id * stride_cachekb block_id * stride_cachekb
@ -63,6 +108,64 @@ def _copy_to_kvcache_seqlen1_kernel(
return return
def copy_k_to_blocked_cache(
k: torch.Tensor, k_cache: torch.Tensor, kv_lengths: torch.Tensor, block_tables: torch.Tensor, n: int = 1
):
"""
Copy keys or values to the blocked key/value cache during the decoding stage.
Args:
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
[bsz * n, num_kv_heads, head_dim] - Keys or values with seq len n
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
n (int): Number of tokens to copy for each sequence. Defaults to 1.
"""
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
k = k.reshape(-1, k.size(-2), k.size(-1)) if k.dim() == 4 else k
assert k.dim() == 3, f"Invalid k dim {k.dim()}"
bsz, num_kv_heads, head_dim = k.shape
# NOTE when n > 1, the shape of k is [bsz * n, num_kv_heads, head_dim]
if n > 1:
assert bsz % n == 0, "Each sequence should have the same number of tokens to be copied"
bsz = bsz // n
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
f"Got incompatible batch size (number of seqs):\n"
f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}"
)
# Modify this if the shape of the kv cache is changed.
block_size = k_cache.size(-2)
num_warps = 8 if head_dim > 128 else 4
grid = (bsz * n, num_kv_heads)
_copy_to_kcache_seqlen_n_kernel[grid](
k,
k_cache,
block_tables,
kv_lengths,
k.stride(0),
k.stride(1),
k.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
block_size,
n=n,
HEAD_DIM=head_dim,
num_warps=num_warps,
)
def copy_kv_to_blocked_cache( def copy_kv_to_blocked_cache(
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,

View File

@ -0,0 +1,71 @@
import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
from colossalai.inference.spec.drafter import Drafter
from colossalai.utils import get_current_device
NUM_LAYERS = 1
MAX_LEN = 100
SPEC_NUM = 5
@pytest.mark.parametrize("spec_num", [SPEC_NUM])
def test_drafter(spec_num: int):
torch.manual_seed(123)
device = get_current_device()
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
toy_config.pad_token_id = tokenizer.eos_token_id
drafter_model = LlamaForCausalLM(toy_config)
drafter_model = drafter_model.eval().cuda()
drafter = Drafter(drafter_model, tokenizer, device=device)
input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
out = drafter.speculate(input_ids, spec_num)
past_kv_length = input_ids.size(1) + spec_num - 1
assert out.speculated_length == spec_num
assert out.next_tokens.shape == (spec_num,)
assert out.logits.shape == (spec_num, len(tokenizer))
assert out.past_key_values[0][0].size(2) == past_kv_length
reject_num = max(0, spec_num - 1)
trimmed_past_key_values = drafter.trim_kv_cache(out.past_key_values, reject_num)
assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
def test_spec_dec():
spec_num = SPEC_NUM
device = get_current_device()
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.eos_token
# Dummy config for Glide Model
glide_config = GlideLlamaConfig(
intermediate_size=8192,
large_hidden_size=4096,
large_num_attention_heads=32,
num_hidden_layers=NUM_LAYERS,
)
drafter_model = GlideLlamaForCausalLM(glide_config)
assert hasattr(drafter_model, "model")
assert hasattr(drafter_model.model, "layers")
for _, layer in enumerate(drafter_model.model.layers):
assert hasattr(layer, "cross_attn")
# Init the Drafter by providing the sharded drafter model
drafter = Drafter(drafter_model, tokenizer, device=device, dtype=torch.float16)
input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
out = drafter.speculate(input_ids, spec_num, past_key_values=None)
if __name__ == "__main__":
test_drafter(spec_num=SPEC_NUM)
test_spec_dec()

View File

@ -9,6 +9,7 @@ import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.flash_decoding_utils import FDIntermTensors from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@ -80,9 +81,81 @@ def check_output_consistency(prompt_template):
FDIntermTensors._instances = {} FDIntermTensors._instances = {}
@parameterize("num_layers", [1])
@parameterize("max_length", [100])
def check_spec_dec(num_layers, max_length):
torch.manual_seed(123)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# Dummy configs for testing
toy_config = LlamaConfig(num_hidden_layers=num_layers)
toy_config.pad_token_id = tokenizer.eos_token_id
drafter_model = LlamaForCausalLM(toy_config)
drafter_model = drafter_model.eval().cuda()
large_config = LlamaConfig(
hidden_size=4096,
intermediate_size=11008,
num_attention_heads=32,
num_hidden_layers=8,
num_key_value_heads=32,
max_position_embeddings=2048,
)
large_config.pad_token_id = tokenizer.eos_token_id
main_model = LlamaForCausalLM(large_config)
inference_config = InferenceConfig(
dtype="fp16",
micro_batch_size=1,
max_batch_size=1,
max_input_len=128,
max_output_len=128,
prefill_ratio=1.2,
block_size=16,
)
engine = InferenceEngine(main_model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
generation_config = GenerationConfig(
pad_token_id=tokenizer.eos_token_id,
max_length=max_length,
eos_token_id=tokenizer.eos_token_id,
)
out, out_token_ids = engine.generate(
prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
)
engine.disable_spec_dec()
engine.clear_spec_dec()
assert not engine.use_spec_dec
assert engine.drafter is None and engine.drafter_model is None
assert len(out) == 1
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
# test GLIDE model
glide_config = GlideLlamaConfig(
intermediate_size=8192,
large_hidden_size=4096,
large_num_attention_heads=32,
num_hidden_layers=num_layers,
)
glide_model = GlideLlamaForCausalLM(glide_config)
engine.enable_spec_dec(glide_model, use_glide_drafter=True)
out, out_token_ids = engine.generate(
prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
)
engine.clear_spec_dec()
assert len(out) == 1
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_output_consistency() check_output_consistency()
check_spec_dec()
@pytest.mark.dist @pytest.mark.dist

View File

@ -19,12 +19,19 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim) return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, kv_seq_len: int, device="cuda"): def create_attention_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"):
padding_mask = torch.zeros((bsz, 1, 1, kv_seq_len), dtype=torch.float32, device=device) assert q_len <= kv_len
causal_mask = torch.full((q_len, q_len), fill_value=float("-inf"), device=device).triu(diagonal=1)
padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device)
for i in range(bsz): for i in range(bsz):
cur_seq_len = kv_lengths[i].item() cur_seq_len = kv_lengths[i].item()
assert cur_seq_len <= kv_seq_len assert cur_seq_len <= kv_len
padding_mask[i, :, :, : kv_seq_len - cur_seq_len] = float("-inf") padding_mask[i, :, :, : kv_len - cur_seq_len] = float("-inf")
padding_mask[:, :, -q_len:, -q_len:] += causal_mask
return padding_mask return padding_mask
@ -33,12 +40,12 @@ def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, kv_seq_len: int, de
# https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350 # https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350
def torch_attn_ref( def torch_attn_ref(
q: torch.Tensor, # [bsz, num_heads, q_len, head_dim] q: torch.Tensor, # [bsz, num_heads, q_len, head_dim]
k: torch.Tensor, # [bsz, num_heads, kv_seq_len, head_dim] k: torch.Tensor, # [bsz, num_heads, kv_len, head_dim]
v: torch.Tensor, # [bsz, num_heads, kv_seq_len, head_dim] v: torch.Tensor, # [bsz, num_heads, kv_len, head_dim]
attention_mask: torch.Tensor, # [bsz, 1, seq_len, kv_seq_len] attention_mask: torch.Tensor, # [bsz, 1, q_len, kv_len]
bsz: int, bsz: int,
seq_len: int, q_len: int,
kv_seq_len: int, kv_len: int,
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
head_dim: int, head_dim: int,
@ -54,22 +61,24 @@ def torch_attn_ref(
qk = torch.matmul(q, k.transpose(2, 3)) qk = torch.matmul(q, k.transpose(2, 3))
attn_scores = qk / (head_dim**0.5) attn_scores = qk / (head_dim**0.5)
assert attn_scores.shape == (bsz, num_heads, seq_len, kv_seq_len), "Invalid shape of attention scores"
# for left-side padding
if attention_mask.size() != (bsz, 1, seq_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, seq_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_scores = attn_scores + attention_mask assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores"
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}"
)
attn_scores = attn_scores + attention_mask
attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype) attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
out = torch.matmul(attn_weights, v) out = torch.matmul(attn_weights, v)
if out.size() != (bsz, num_heads, seq_len, head_dim): if out.size() != (bsz, num_heads, q_len, head_dim):
raise ValueError( raise ValueError(
f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}" f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}, but is" f" {out.size()}"
) )
out = out.transpose(1, 2).contiguous() out = out.transpose(1, 2).contiguous()
out = out.squeeze(1) out = out.view(-1, out.size(-2), out.size(-1))
# out [bsz * q_len, num_heads, head_dim]
return out return out

View File

@ -6,8 +6,8 @@ from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import ( from tests.test_infer.test_ops.triton.kernel_utils import (
convert_kv_unpad_to_padded, convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v2, generate_caches_and_block_tables_v2,
prepare_padding_mask,
torch_attn_ref, torch_attn_ref,
) )
@ -21,7 +21,6 @@ except ImportError:
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
Q_LEN = 1
HEAD_DIM = 128 HEAD_DIM = 128
@ -64,6 +63,7 @@ def prepare_data(
@pytest.mark.parametrize("num_attn_heads", [16]) @pytest.mark.parametrize("num_attn_heads", [16])
@pytest.mark.parametrize("kv_group_num", [1, 2, 16]) @pytest.mark.parametrize("kv_group_num", [1, 2, 16])
@pytest.mark.parametrize("same_context_len", [True, False]) @pytest.mark.parametrize("same_context_len", [True, False])
@pytest.mark.parametrize("q_len", [1, 5])
def test_flash_decoding( def test_flash_decoding(
bsz: int, bsz: int,
block_size: int, block_size: int,
@ -71,6 +71,7 @@ def test_flash_decoding(
num_attn_heads: int, num_attn_heads: int,
kv_group_num: int, kv_group_num: int,
same_context_len: bool, same_context_len: bool,
q_len: int,
): ):
torch.manual_seed(123) torch.manual_seed(123)
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -82,47 +83,57 @@ def test_flash_decoding(
max_seq_len = block_size * max_num_blocks_per_seq max_seq_len = block_size * max_num_blocks_per_seq
dtype = torch.float16 dtype = torch.float16
device = get_current_device() device = get_current_device()
q, k_unpad, v_unpad, kv_lengths = prepare_data(
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data( bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, q_len, max_seq_len, dtype, device
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
) )
# The maximum sequence length in the batch (if context lengths randomly generated)
max_kv_len_in_b = kv_lengths.max().item()
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b)
attention_mask = create_attention_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device)
out_torch = torch_attn_ref(
q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
) )
block_tables = block_tables.to(device=device) block_tables = block_tables.to(device=device)
# The maximum sequence length in the batch (if context lengths randomly generated)
max_seq_len_in_b = kv_seq_lengths.max().item()
# The maximum block length splitted on kv should be the kv cache block size # The maximum block length splitted on kv should be the kv cache block size
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size kv_max_split_num = (max_kv_len_in_b + block_size - 1) // block_size
output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device) output = torch.empty((bsz * q_len, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
mid_output = torch.empty( mid_output = torch.empty(
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device size=(bsz * q_len, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)
mid_output_lse = torch.empty(
size=(bsz * q_len, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device
) )
mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
sm_scale = 1.0 / (HEAD_DIM**0.5) sm_scale = 1.0 / (HEAD_DIM**0.5)
# Here we use different methods to hide the q_len dimension,
# refer to attention forward function in modeling.
if q_len > 1:
q = q.transpose(1, 2).contiguous() # [bsz, q_len, num_heads, head_dim]
q = q.view(-1, q.size(-2), q.size(-1)) # [bsz * q_len, num_heads, head_dim]
else:
q = q.squeeze(2)
assert q.shape == (bsz * q_len, num_attn_heads, HEAD_DIM)
out_triton = flash_decoding_attention( out_triton = flash_decoding_attention(
# Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1), q,
# refer to attention forward in modeling.
q.squeeze(2),
k_cache, k_cache,
v_cache, v_cache,
kv_seq_lengths, kv_lengths,
block_tables, block_tables,
block_size, block_size,
max_seq_len_in_b, max_kv_len_in_b,
output, output,
mid_output, mid_output,
mid_output_lse, mid_output_lse,
sm_scale=sm_scale, sm_scale=sm_scale,
kv_group_num=kv_group_num, kv_group_num=kv_group_num,
) # [bsz, 1, num_heads, head_dim] q_len=q_len,
) # [bsz * q_len, num_heads, head_dim]
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, bsz, max_seq_len_in_b)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, bsz, max_seq_len_in_b)
torch_padding_mask = prepare_padding_mask(kv_seq_lengths, bsz, max_seq_len_in_b, q.device)
out_torch = torch_attn_ref(
q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
)
assert out_torch.shape == out_triton.shape assert out_torch.shape == out_triton.shape
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4) assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)

View File

@ -2,7 +2,7 @@ import pytest
import torch import torch
from packaging import version from packaging import version
from colossalai.kernel.triton import copy_kv_to_blocked_cache from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
@ -16,7 +16,7 @@ except ImportError:
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
HEAD_DIM = 128 HEAD_DIM = 32
def prepare_data( def prepare_data(
@ -27,15 +27,16 @@ def prepare_data(
max_num_blocks_per_seq, max_num_blocks_per_seq,
same_context_len, same_context_len,
max_seq_len, max_seq_len,
device, n=1,
device="cuda",
dtype=torch.float16, dtype=torch.float16,
): ):
# past_kv_seq_lengths in this test records the previous kv seq len assert max_seq_len > n, "max_seq_len must be greater than n"
# (not incorporating the current input whose seq len is 1)
past_kv_seq_lengths = ( past_kv_seq_lengths = (
torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device) torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device)
if same_context_len if same_context_len
else torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device) else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device)
) )
num_tokens = torch.sum(past_kv_seq_lengths).item() num_tokens = torch.sum(past_kv_seq_lengths).item()
@ -48,14 +49,14 @@ def prepare_data(
) )
block_tables = block_tables.to(device=device) block_tables = block_tables.to(device=device)
new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device) new_k = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
new_v = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device) new_v = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
# mock allocating blocks for the new k/v and update block tables # mock allocating blocks for the new k/v and update block tables
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size) for _ in range(n):
# kv seq len = past kv seq len + seq len (1 during decoding stage) mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
kv_seq_lengths = past_kv_seq_lengths + 1 past_kv_seq_lengths += 1
return new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables return new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@ -64,12 +65,9 @@ def prepare_data(
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32]) @pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
@pytest.mark.parametrize("num_kv_heads", [16]) @pytest.mark.parametrize("num_kv_heads", [16])
@pytest.mark.parametrize("same_context_len", [True, False]) @pytest.mark.parametrize("same_context_len", [True, False])
@pytest.mark.parametrize("n_tokens", [1, 5])
def test_copy_kv_to_caches( def test_copy_kv_to_caches(
bsz: int, bsz: int, block_size: int, max_num_blocks_per_seq: int, num_kv_heads: int, same_context_len: bool, n_tokens: int
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
): ):
torch.manual_seed(123) torch.manual_seed(123)
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -88,25 +86,49 @@ def test_copy_kv_to_caches(
max_num_blocks_per_seq, max_num_blocks_per_seq,
same_context_len, same_context_len,
max_seq_len, max_seq_len,
n_tokens,
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
# k_cache_torch = k_cache.clone().detach() k_source = new_k.view(-1, new_k.size(-2), new_k.size(-1))
# copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding") v_source = new_v.view(-1, new_v.size(-2), new_v.size(-1))
copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables) k_cache_copy = k_cache.detach().clone()
past_kv_seq_lengths = kv_seq_lengths - n_tokens
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_lengths // block_size]
offsets_in_block = past_kv_seq_lengths % block_size
past_kv_seq_len = kv_seq_lengths - 1 # Copy k (or v) to k (or v) cache
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size] copy_k_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens)
offsets_in_block = past_kv_seq_len % block_size # Reshape target k from k cache to compare if matching with original tensor
k_target = k_cache[target_block_ids, :, offsets_in_block, :] # Mainly to handle cases of n_tokens > 1
k_source = new_k.squeeze() k_target = []
v_target = v_cache[target_block_ids, :, offsets_in_block, :] for i in range(bsz):
v_source = new_v.squeeze() block_table = block_tables[i]
curr_kv_len = past_kv_seq_lengths[i].item()
offset = offsets_in_block[i].item()
tokens_left = n_tokens
while tokens_left > 0:
tokens_to_fill = min(block_size - offset, tokens_left)
curr_block_id = block_table[curr_kv_len // block_size]
k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :])
curr_kv_len += tokens_to_fill
tokens_left -= tokens_to_fill
offset = 0
k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous() # [bsz * n, num_kv_heads, head_dim]
assert k_target.shape == k_source.shape assert k_target.shape == k_source.shape
assert torch.equal(k_target, k_source) assert torch.equal(k_target, k_source)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source) if n_tokens == 1:
# Copy k and v to k/v caches
k_cache = k_cache_copy
copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
k_target = k_cache_copy[target_block_ids, :, offsets_in_block, :]
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
assert k_target.shape == k_source.shape
assert torch.equal(k_target, k_source)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
if __name__ == "__main__": if __name__ == "__main__":