mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-01 03:45:27 +00:00
[Inference/Spec-Dec] Merge pull request #5565 from hpcaitech/feat/speculative-decoding
Add Speculative Decoding and GLIDE Spec-Dec
This commit is contained in:
commit
25928d8496
@ -133,7 +133,7 @@ We offer 3 main sampling strategies now (i.e. `greedy sample`, `multinomial samp
|
|||||||
|
|
||||||
| Model | KV Cache | Paged Attention | Kernels | Tensor Parallelism | Speculative Decoding |
|
| Model | KV Cache | Paged Attention | Kernels | Tensor Parallelism | Speculative Decoding |
|
||||||
| - | - | - | - | - | - |
|
| - | - | - | - | - | - |
|
||||||
| Llama | ✅ | ✅ | ✅ | 🔜 | 🔜 |
|
| Llama | ✅ | ✅ | ✅ | 🔜 | ✅ |
|
||||||
|
|
||||||
|
|
||||||
Notations:
|
Notations:
|
||||||
@ -148,7 +148,7 @@ Notations:
|
|||||||
- [x] High-Performance Kernels
|
- [x] High-Performance Kernels
|
||||||
- [x] Llama Modelling
|
- [x] Llama Modelling
|
||||||
- [x] User Documentation
|
- [x] User Documentation
|
||||||
- [ ] Speculative Decoding
|
- [x] Speculative Decoding
|
||||||
- [ ] Tensor Parallelism
|
- [ ] Tensor Parallelism
|
||||||
- [ ] Beam Search
|
- [ ] Beam Search
|
||||||
- [ ] Early stopping
|
- [ ] Early stopping
|
||||||
|
@ -42,6 +42,9 @@ class BatchBucket:
|
|||||||
self.device = device or get_current_device()
|
self.device = device or get_current_device()
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
|
self._use_spec_dec = False
|
||||||
|
self._num_tokens_to_verify = None
|
||||||
|
|
||||||
self._current_batch_size = 0
|
self._current_batch_size = 0
|
||||||
self._sequences_dict = dict()
|
self._sequences_dict = dict()
|
||||||
self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size)
|
self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size)
|
||||||
@ -88,6 +91,27 @@ class BatchBucket:
|
|||||||
== torch.nonzero(self._block_tables[:, 0] >= 0).numel()
|
== torch.nonzero(self._block_tables[:, 0] >= 0).numel()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_spec_dec(self) -> bool:
|
||||||
|
return self._use_spec_dec
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_tokens_to_verify(self) -> int:
|
||||||
|
return self._num_tokens_to_verify
|
||||||
|
|
||||||
|
def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
|
||||||
|
"""Set batch bucket to use speculatvie decoding.
|
||||||
|
This will notify the adjust the lengths of inputs during modeling,
|
||||||
|
and let the main model verifies tokens in parallel.
|
||||||
|
"""
|
||||||
|
self._use_spec_dec = True
|
||||||
|
self._num_tokens_to_verify = num_tokens_to_verify
|
||||||
|
|
||||||
|
def reset_use_spec_dec(self) -> None:
|
||||||
|
"""Reset the usage of speculative decoding for the batch bucket"""
|
||||||
|
self._use_spec_dec = False
|
||||||
|
self._num_tokens_to_verify = None
|
||||||
|
|
||||||
def _make_compact(self) -> None:
|
def _make_compact(self) -> None:
|
||||||
# Clean and Compress the batch based on its sequences dict.
|
# Clean and Compress the batch based on its sequences dict.
|
||||||
# Namely,compress sequences to the front and clean the seq lengths and block tables tensors.
|
# Namely,compress sequences to the front and clean the seq lengths and block tables tensors.
|
||||||
@ -347,6 +371,23 @@ class BatchBucket:
|
|||||||
seq.check_finish()
|
seq.check_finish()
|
||||||
self._sequence_lengths[: self.current_batch_size] += 1
|
self._sequence_lengths[: self.current_batch_size] += 1
|
||||||
|
|
||||||
|
def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
|
||||||
|
"""Revoke the last n output tokens of the sequences in the batch
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_tokens (int): The number of output tokens to revoke from each sequence.
|
||||||
|
It does not count in the context tokens (input tokens).
|
||||||
|
n_seqs (int): The first n sequences to revoke tokens from. Defaults to 1.
|
||||||
|
For now, speculative decoding only supports batch size 1.
|
||||||
|
"""
|
||||||
|
if n_tokens >= 1:
|
||||||
|
seqs_iter = iter(self._sequences_dict.items())
|
||||||
|
for _ in range(n_seqs):
|
||||||
|
seq_id, seq = next(seqs_iter)
|
||||||
|
assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
|
||||||
|
seq.output_token_id = seq.output_token_id[:-n_tokens]
|
||||||
|
self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens
|
||||||
|
|
||||||
def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
|
def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
|
||||||
"""Clear all the sequences in the batch.
|
"""Clear all the sequences in the batch.
|
||||||
|
|
||||||
@ -401,6 +442,21 @@ class BatchBucket:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def get_1D_inputs_spec_dec(self, n: int) -> torch.Tensor:
|
||||||
|
# Used for main model verification in **Decoding Stage**
|
||||||
|
# `n` is the number of tokens to be verified,
|
||||||
|
# and so that prepare the last `n` tokens of each sequence as the inputs
|
||||||
|
assert len(self._sequences_dict) > 0, "No sequence in the batch"
|
||||||
|
assert all(
|
||||||
|
seq.output_len >= n for seq in self._sequences_dict.values()
|
||||||
|
), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
|
||||||
|
out_li = []
|
||||||
|
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
|
||||||
|
for seq_id in seq_ids:
|
||||||
|
seq: Sequence = self._sequences_dict[seq_id]
|
||||||
|
out_li.extend(seq.output_token_id[-n:])
|
||||||
|
return torch.tensor(out_li, dtype=torch.long, device=self.device)
|
||||||
|
|
||||||
# For compatibility
|
# For compatibility
|
||||||
def get_1D_inputs(self) -> torch.Tensor:
|
def get_1D_inputs(self) -> torch.Tensor:
|
||||||
assert len(self._sequences_dict) > 0, "No sequence in the batch"
|
assert len(self._sequences_dict) > 0, "No sequence in the batch"
|
||||||
@ -411,8 +467,6 @@ class BatchBucket:
|
|||||||
seq.output_len == 0 for seq in self._sequences_dict.values()
|
seq.output_len == 0 for seq in self._sequences_dict.values()
|
||||||
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
|
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
|
||||||
out_li = []
|
out_li = []
|
||||||
num_tokens = torch.sum(self._sequence_lengths)
|
|
||||||
out = torch.empty([num_tokens], dtype=torch.long)
|
|
||||||
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
|
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
|
||||||
for seq_id in seq_ids:
|
for seq_id in seq_ids:
|
||||||
seq: Sequence = self._sequences_dict[seq_id]
|
seq: Sequence = self._sequences_dict[seq_id]
|
||||||
@ -420,6 +474,10 @@ class BatchBucket:
|
|||||||
return torch.tensor(out_li, dtype=torch.long, device=self.device)
|
return torch.tensor(out_li, dtype=torch.long, device=self.device)
|
||||||
else:
|
else:
|
||||||
# Assume decoding stage
|
# Assume decoding stage
|
||||||
|
if self.use_spec_dec:
|
||||||
|
# For Speculative Decoding
|
||||||
|
# the number of tokens to be verified in parallel plus the correct token in the last step
|
||||||
|
return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
|
||||||
assert all(
|
assert all(
|
||||||
seq.output_len > 0 for seq in self._sequences_dict.values()
|
seq.output_len > 0 for seq in self._sequences_dict.values()
|
||||||
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
|
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
|
||||||
|
@ -26,7 +26,7 @@ _ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
|
|||||||
|
|
||||||
_DEFAULT_PROMPT_TEMPLATES = {
|
_DEFAULT_PROMPT_TEMPLATES = {
|
||||||
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
|
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
|
||||||
"vicuna": "USER: {input_text}\n\nASSISTANT: ",
|
"vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -46,6 +46,8 @@ class InputMetaData:
|
|||||||
head_dim (int, optional): Head dimension. Defaults to 32.
|
head_dim (int, optional): Head dimension. Defaults to 32.
|
||||||
high_precision(bool, optional): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, Defaults to False.
|
high_precision(bool, optional): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, Defaults to False.
|
||||||
dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32.
|
dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32.
|
||||||
|
use_spec_dec (bool): Indicate whether to use speculative decoding.
|
||||||
|
num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
block_tables: torch.Tensor = None
|
block_tables: torch.Tensor = None
|
||||||
@ -59,9 +61,22 @@ class InputMetaData:
|
|||||||
head_dim: int = 32
|
head_dim: int = 32
|
||||||
high_precision: bool = False
|
high_precision: bool = False
|
||||||
dtype: torch.dtype = torch.float32
|
dtype: torch.dtype = torch.float32
|
||||||
|
use_spec_dec: bool = False
|
||||||
|
num_tokens_to_verify: int = 0
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"InputMetaData(block_tables={self.block_tables}, sequence_lengths={self.sequence_lengths}, fd_inter_tensor={self.fd_inter_tensor}, batch_size={self.batch_size}, is_prompts={self.is_prompts}, use_cuda_graph={self.use_cuda_graph}, kv_seq_len={self.kv_seq_len}, head_dim={self.head_dim})"
|
return (
|
||||||
|
f"InputMetaData(block_tables={self.block_tables}, "
|
||||||
|
f"sequence_lengths={self.sequence_lengths}, "
|
||||||
|
f"fd_inter_tensor={self.fd_inter_tensor}, "
|
||||||
|
f"batch_size={self.batch_size}, "
|
||||||
|
f"is_prompts={self.is_prompts}, "
|
||||||
|
f"use_cuda_kernel={self.use_cuda_kernel}, "
|
||||||
|
f"use_cuda_graph={self.use_cuda_graph}, "
|
||||||
|
f"kv_seq_len={self.kv_seq_len}, "
|
||||||
|
f"use_spec_dec={self.use_spec_dec}, "
|
||||||
|
f"num_tokens_to_verify={self.num_tokens_to_verify})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -84,6 +99,8 @@ class InferenceConfig:
|
|||||||
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
|
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
|
||||||
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
|
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
|
||||||
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
|
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
|
||||||
|
n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
|
||||||
|
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
|
||||||
block_size (int): The number of blocks in a logical block, defaults to 16.
|
block_size (int): The number of blocks in a logical block, defaults to 16.
|
||||||
tp_size (int): Tensor parallel size, defaults to 1.
|
tp_size (int): Tensor parallel size, defaults to 1.
|
||||||
pp_size (int): Pipeline parallel size, defaults to 1.
|
pp_size (int): Pipeline parallel size, defaults to 1.
|
||||||
@ -118,6 +135,10 @@ class InferenceConfig:
|
|||||||
top_p: Optional[float] = None
|
top_p: Optional[float] = None
|
||||||
min_p: Optional[float] = None
|
min_p: Optional[float] = None
|
||||||
|
|
||||||
|
# speculative decoding configs
|
||||||
|
max_n_spec_tokens: int = 5
|
||||||
|
glimpse_large_kv: bool = False
|
||||||
|
|
||||||
# paged attention configs
|
# paged attention configs
|
||||||
block_size: int = 16
|
block_size: int = 16
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ from colossalai.inference.batch_bucket import BatchBucket
|
|||||||
from colossalai.inference.config import InferenceConfig, InputMetaData
|
from colossalai.inference.config import InferenceConfig, InputMetaData
|
||||||
from colossalai.inference.graph_runner import CUDAGraphRunner
|
from colossalai.inference.graph_runner import CUDAGraphRunner
|
||||||
from colossalai.inference.modeling.policy import model_policy_map
|
from colossalai.inference.modeling.policy import model_policy_map
|
||||||
|
from colossalai.inference.spec import Drafter, GlideInput
|
||||||
from colossalai.inference.struct import Sequence
|
from colossalai.inference.struct import Sequence
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
@ -52,19 +53,27 @@ class InferenceEngine:
|
|||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
model_policy: Policy = None,
|
model_policy: Policy = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert inference_config, "Please provide inference_config."
|
|
||||||
assert tokenizer, "Please provide a tokenizer, either a defined one or str"
|
|
||||||
self.inference_config = inference_config
|
self.inference_config = inference_config
|
||||||
self.model_config = model.config
|
self.model_config = model.config
|
||||||
|
self.model = model
|
||||||
self.device = torch.device("cuda")
|
self.device = torch.device("cuda")
|
||||||
self.dtype = inference_config.dtype
|
self.dtype = inference_config.dtype
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
self.generation_config = inference_config.to_generation_config(self.model_config)
|
|
||||||
self.high_precision = inference_config.high_precision
|
self.high_precision = inference_config.high_precision
|
||||||
model = model.eval()
|
self._verify_args()
|
||||||
model = model.cuda()
|
|
||||||
model.to(self.dtype)
|
self.generation_config = inference_config.to_generation_config(self.model_config)
|
||||||
|
model.eval()
|
||||||
|
model = model.to(self.dtype)
|
||||||
|
model = model.to(self.device)
|
||||||
|
|
||||||
|
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
|
||||||
|
self.use_spec_dec = False
|
||||||
|
self.drafter_model = None
|
||||||
|
self.drafter = None
|
||||||
|
self.use_glide = False
|
||||||
|
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||||
|
|
||||||
if model_policy is None:
|
if model_policy is None:
|
||||||
if self.inference_config.pad_input:
|
if self.inference_config.pad_input:
|
||||||
@ -174,21 +183,18 @@ class InferenceEngine:
|
|||||||
if self.verbose:
|
if self.verbose:
|
||||||
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
|
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
|
||||||
|
|
||||||
def _verify_config(self) -> None:
|
def _verify_args(self) -> None:
|
||||||
"""
|
"""Verify the input args"""
|
||||||
Verify the input config
|
if not isinstance(self.inference_config, InferenceConfig):
|
||||||
"""
|
raise TypeError("Invalid type of inference config provided.")
|
||||||
if not isinstance(self.model, nn.Module):
|
if not isinstance(self.model, nn.Module):
|
||||||
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
|
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
|
||||||
if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
|
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
|
||||||
self.tokenizer, PreTrainedTokenizer
|
|
||||||
):
|
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
|
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
|
||||||
)
|
)
|
||||||
assert (
|
if self.model.__class__.__name__ not in _supported_models:
|
||||||
self.model.__class__.__name__ in _supported_models
|
raise ValueError(f"Model {self.model.__class__.__name__} is not supported.")
|
||||||
), f"Model {self.model.__class__.__name__} is not supported."
|
|
||||||
|
|
||||||
def _shardformer(
|
def _shardformer(
|
||||||
self,
|
self,
|
||||||
@ -224,6 +230,199 @@ class InferenceEngine:
|
|||||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||||
return shard_model
|
return shard_model
|
||||||
|
|
||||||
|
def enable_spec_dec(
|
||||||
|
self,
|
||||||
|
drafter_model: nn.Module = None,
|
||||||
|
n_spec_tokens: int = None,
|
||||||
|
use_glide_drafter: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
|
||||||
|
If provided, the previous drafter and drafter model, if exist, will be overwritten.
|
||||||
|
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
|
||||||
|
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
|
||||||
|
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
|
||||||
|
If True, the drafter model will be replaced by a glide model.
|
||||||
|
|
||||||
|
```python
|
||||||
|
...
|
||||||
|
engine = InferenceEngine(model, tokenizer, inference_config)
|
||||||
|
|
||||||
|
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
|
||||||
|
engine.generate(...) # Speculative Decoding
|
||||||
|
|
||||||
|
engine.disable_spec_dec()
|
||||||
|
engine.generate(...) # Normal generation
|
||||||
|
|
||||||
|
engine.enable_spec_dec()
|
||||||
|
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
|
||||||
|
engine.clear_spec_dec()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if drafter_model is None and self.drafter is None:
|
||||||
|
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
|
||||||
|
if n_spec_tokens is not None:
|
||||||
|
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
|
||||||
|
self.n_spec_tokens = n_spec_tokens
|
||||||
|
if drafter_model is not None:
|
||||||
|
assert isinstance(drafter_model, nn.Module)
|
||||||
|
# overwrite the drafter, if exists
|
||||||
|
self.clear_spec_dec()
|
||||||
|
self.drafter_model = drafter_model
|
||||||
|
self.drafter = Drafter(
|
||||||
|
self.drafter_model,
|
||||||
|
self.tokenizer,
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if the provided drafter model is compatible with GLIDE structure
|
||||||
|
# when `use_glide_drafter` is set to True
|
||||||
|
if (
|
||||||
|
use_glide_drafter
|
||||||
|
and hasattr(drafter_model, "model")
|
||||||
|
and hasattr(drafter_model.model, "layers")
|
||||||
|
and hasattr(drafter_model.model.layers[0], "cross_attn")
|
||||||
|
):
|
||||||
|
self.use_glide = use_glide_drafter
|
||||||
|
elif use_glide_drafter:
|
||||||
|
self.logger.warning(
|
||||||
|
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
|
||||||
|
f"but the provided drafter model is not compatible with GLIDE structure."
|
||||||
|
f"Falling back to use the default drafter model (non-GLIDE)."
|
||||||
|
)
|
||||||
|
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
|
||||||
|
# using speculative decoding for subsequent generations
|
||||||
|
self.use_spec_dec = True
|
||||||
|
|
||||||
|
def disable_spec_dec(self) -> None:
|
||||||
|
"""Disable using speculative decoding for subsequent generations."""
|
||||||
|
self.request_handler.unset_spec_dec_mode()
|
||||||
|
# set back to the maximum number of tokens to speculate
|
||||||
|
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||||
|
self.use_glide = False
|
||||||
|
self.use_spec_dec = False
|
||||||
|
|
||||||
|
def clear_spec_dec(self) -> None:
|
||||||
|
"""Clear relatable structures of speculative decoding, if exist."""
|
||||||
|
if self.use_spec_dec:
|
||||||
|
self.disable_spec_dec()
|
||||||
|
if self.drafter_model or self.drafter:
|
||||||
|
self.drafter_model = None
|
||||||
|
self.drafter = None
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
self.use_glide = False
|
||||||
|
self.use_spec_dec = False
|
||||||
|
|
||||||
|
def steps_spec_dec(self) -> List[Sequence]:
|
||||||
|
"""
|
||||||
|
Run Speculative Decoding steps. This is like retrieving a single batch and launch inference
|
||||||
|
with many steps of speculating by a drafter model as well as verifying by a main model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Sequence]: finished sequences generated by one step.
|
||||||
|
"""
|
||||||
|
batch = self.request_handler.schedule() # prefill batch
|
||||||
|
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||||
|
|
||||||
|
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||||
|
|
||||||
|
if input_meta_data.use_cuda_graph:
|
||||||
|
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||||
|
else:
|
||||||
|
model_executable = self.model
|
||||||
|
|
||||||
|
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
|
||||||
|
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
|
||||||
|
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
|
||||||
|
next_token_ids_spec = drafter_out.next_tokens
|
||||||
|
drafter_past_key_values = drafter_out.past_key_values
|
||||||
|
|
||||||
|
# 2. Prefill main model (Verifier) - fill past kv cache for main model
|
||||||
|
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||||
|
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
|
||||||
|
# append new inputs to the batch, temporarily
|
||||||
|
batch.append_batch_tokens(next_tokens)
|
||||||
|
self.request_handler.allocate_batch_spec_dec(batch, 1)
|
||||||
|
already_allocated_kv_len = batch.seq_lengths[0].item()
|
||||||
|
input_token_ids = batch.get_1D_inputs_spec_dec(1)
|
||||||
|
|
||||||
|
finished_sequences = self.request_handler.update()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# HACK Retrieve the running batch
|
||||||
|
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
|
||||||
|
batch = self.request_handler.running_bb # running batch
|
||||||
|
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||||
|
|
||||||
|
# 3. Decoding - Drafter model speculates `n` tokens
|
||||||
|
glide_input = None
|
||||||
|
if self.use_glide:
|
||||||
|
glide_input = GlideInput(
|
||||||
|
batch.get_block_table_tensor(),
|
||||||
|
self.k_cache[-1], # use kv cahces of the last layer
|
||||||
|
self.v_cache[-1],
|
||||||
|
batch.get_sequence_lengths(),
|
||||||
|
)
|
||||||
|
|
||||||
|
drafter_out = self.drafter.speculate(
|
||||||
|
input_token_ids,
|
||||||
|
self.n_spec_tokens,
|
||||||
|
drafter_past_key_values,
|
||||||
|
glide_input=glide_input,
|
||||||
|
)
|
||||||
|
next_token_ids_spec = drafter_out.next_tokens
|
||||||
|
drafter_past_key_values = drafter_out.past_key_values
|
||||||
|
drafter_spec_length = drafter_out.speculated_length
|
||||||
|
|
||||||
|
for next_token_id_spec in next_token_ids_spec:
|
||||||
|
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
|
||||||
|
cur_length = batch.seq_lengths[0].item()
|
||||||
|
if already_allocated_kv_len < cur_length:
|
||||||
|
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
|
||||||
|
already_allocated_kv_len = cur_length
|
||||||
|
|
||||||
|
# 4. Decoding - Main model verifies `n` tokens in parallel
|
||||||
|
if drafter_spec_length < batch.num_tokens_to_verify:
|
||||||
|
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
|
||||||
|
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||||
|
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||||
|
|
||||||
|
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
|
||||||
|
|
||||||
|
# 5. Compare and process the results
|
||||||
|
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
|
||||||
|
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
|
||||||
|
|
||||||
|
# revoke appended tokens for each Sequence in the current batch
|
||||||
|
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
|
||||||
|
|
||||||
|
# append the last correct token generated by the main model
|
||||||
|
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
|
||||||
|
|
||||||
|
# trim past key values of the drafter model
|
||||||
|
drafter_past_key_values = Drafter.trim_kv_cache(
|
||||||
|
drafter_past_key_values, drafter_spec_length - n_matches - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# prepare inputs for the next round of speculation
|
||||||
|
n = 1 if n_matches < drafter_spec_length else 2
|
||||||
|
input_token_ids = batch.get_1D_inputs_spec_dec(n)
|
||||||
|
|
||||||
|
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
|
||||||
|
finished_sequences = self.request_handler.update()
|
||||||
|
if len(finished_sequences) > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Reset back the number of speculated tokens of the batch,
|
||||||
|
# this is used to handle the last round of speculation, in which case the number of speculated tokens
|
||||||
|
# by the drafter is less than the number of speculated tokens set to the engine.
|
||||||
|
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
|
||||||
|
|
||||||
|
return finished_sequences
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompts: List[str] = None,
|
prompts: List[str] = None,
|
||||||
@ -246,7 +445,6 @@ class InferenceEngine:
|
|||||||
List[str]: Inference result returned by one generation.
|
List[str]: Inference result returned by one generation.
|
||||||
"""
|
"""
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
self.generation_config = generation_config
|
|
||||||
if prompts is not None or prompts_token_ids is not None:
|
if prompts is not None or prompts_token_ids is not None:
|
||||||
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
|
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
|
||||||
|
|
||||||
@ -257,8 +455,13 @@ class InferenceEngine:
|
|||||||
if generation_config is not None:
|
if generation_config is not None:
|
||||||
self.generation_config = generation_config
|
self.generation_config = generation_config
|
||||||
|
|
||||||
while self.request_handler.check_unfinished_seqs():
|
if self.use_spec_dec:
|
||||||
output_seqs_list += self.step()
|
assert self.drafter is not None, "Drafter Model is not initialized."
|
||||||
|
while self.request_handler.check_unfinished_seqs():
|
||||||
|
output_seqs_list += self.steps_spec_dec()
|
||||||
|
else:
|
||||||
|
while self.request_handler.check_unfinished_seqs():
|
||||||
|
output_seqs_list += self.step()
|
||||||
|
|
||||||
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
|
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
|
||||||
|
|
||||||
@ -368,18 +571,19 @@ class InferenceEngine:
|
|||||||
|
|
||||||
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
|
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
|
||||||
input_ids = batch.get_1D_inputs()
|
input_ids = batch.get_1D_inputs()
|
||||||
|
|
||||||
sequence_lengths = batch.get_sequence_lengths()
|
sequence_lengths = batch.get_sequence_lengths()
|
||||||
|
|
||||||
if batch.is_prompts:
|
if batch.is_prompts:
|
||||||
output_tensor = torch.zeros(
|
n_tokens = sequence_lengths.sum().item()
|
||||||
(sequence_lengths.sum().item(), batch.num_heads * batch.head_dim),
|
|
||||||
dtype=batch.dtype,
|
|
||||||
device=batch.device,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
output_tensor = torch.zeros(
|
n_tokens = batch.current_batch_size
|
||||||
(batch.current_batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
|
if batch.use_spec_dec:
|
||||||
)
|
n_tokens = batch.num_tokens_to_verify + 1
|
||||||
|
assert n_tokens == input_ids.size(0)
|
||||||
|
n_tokens = n_tokens * batch.current_batch_size
|
||||||
|
output_tensor = torch.zeros(
|
||||||
|
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
|
||||||
|
)
|
||||||
|
|
||||||
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
|
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
|
||||||
use_cuda_graph = False
|
use_cuda_graph = False
|
||||||
@ -398,6 +602,8 @@ class InferenceEngine:
|
|||||||
kv_seq_len=sequence_lengths.max().item(),
|
kv_seq_len=sequence_lengths.max().item(),
|
||||||
head_dim=batch.head_dim,
|
head_dim=batch.head_dim,
|
||||||
dtype=batch.dtype,
|
dtype=batch.dtype,
|
||||||
|
use_spec_dec=batch.use_spec_dec,
|
||||||
|
num_tokens_to_verify=batch.num_tokens_to_verify,
|
||||||
)
|
)
|
||||||
|
|
||||||
return input_ids, output_tensor, input_meta_data
|
return input_ids, output_tensor, input_meta_data
|
||||||
@ -428,7 +634,8 @@ class InferenceEngine:
|
|||||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||||
if self.inference_config.pad_input:
|
if self.inference_config.pad_input:
|
||||||
logits = logits[:, -1, :]
|
logits = logits[:, -1, :]
|
||||||
self.request_handler.search_tokens(self.generation_config, logits)
|
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
|
||||||
|
self.request_handler.append_next_tokens(next_tokens)
|
||||||
|
|
||||||
finished_sequences = self.request_handler.update()
|
finished_sequences = self.request_handler.update()
|
||||||
|
|
||||||
|
@ -134,8 +134,12 @@ class RequestHandler:
|
|||||||
if fd_inter_tensor._tensors_initialized:
|
if fd_inter_tensor._tensors_initialized:
|
||||||
fd_inter_tensor._reset()
|
fd_inter_tensor._reset()
|
||||||
|
|
||||||
|
# For Spec-Dec, process the speculated tokens plus the token in the last step for each seq
|
||||||
|
max_n_tokens = self.max_batch_size
|
||||||
|
max_n_tokens *= self.inference_config.max_n_spec_tokens + 1
|
||||||
|
|
||||||
fd_inter_tensor.initialize(
|
fd_inter_tensor.initialize(
|
||||||
max_batch_size=self.max_batch_size,
|
max_batch_size=max_n_tokens,
|
||||||
num_attn_heads=model_config.num_attention_heads,
|
num_attn_heads=model_config.num_attention_heads,
|
||||||
kv_max_split_num=kv_max_split_num,
|
kv_max_split_num=kv_max_split_num,
|
||||||
head_dim=head_dim,
|
head_dim=head_dim,
|
||||||
@ -177,6 +181,14 @@ class RequestHandler:
|
|||||||
def get_kvcache(self):
|
def get_kvcache(self):
|
||||||
return self.cache_manager.get_kv_cache()
|
return self.cache_manager.get_kv_cache()
|
||||||
|
|
||||||
|
def set_spec_dec_mode(self, n_spec_tokens: int):
|
||||||
|
self.prefill_bb.set_use_spec_dec(n_spec_tokens)
|
||||||
|
self.running_bb.set_use_spec_dec(n_spec_tokens)
|
||||||
|
|
||||||
|
def unset_spec_dec_mode(self):
|
||||||
|
self.prefill_bb.reset_use_spec_dec()
|
||||||
|
self.running_bb.reset_use_spec_dec()
|
||||||
|
|
||||||
def schedule(self):
|
def schedule(self):
|
||||||
"""
|
"""
|
||||||
The main logic of request handler.
|
The main logic of request handler.
|
||||||
@ -204,7 +216,11 @@ class RequestHandler:
|
|||||||
lst.remove(seq)
|
lst.remove(seq)
|
||||||
|
|
||||||
if self.running_list.ready_for_prefill():
|
if self.running_list.ready_for_prefill():
|
||||||
num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size)
|
num_seqs_to_add = min(self.running_list.prefill_seq_num, self.prefill_bb.available_batch_size)
|
||||||
|
# overwrite the number of sequences to add to 1 if use_spec_dec is enabled
|
||||||
|
# TODO (zhaoyuanheng): support speculative decoding for batch size > 1
|
||||||
|
if self.prefill_bb.use_spec_dec:
|
||||||
|
num_seqs_to_add = 1
|
||||||
|
|
||||||
for seq in self.running_list.prefill[:num_seqs_to_add]:
|
for seq in self.running_list.prefill[:num_seqs_to_add]:
|
||||||
seq.mark_running()
|
seq.mark_running()
|
||||||
@ -230,6 +246,13 @@ class RequestHandler:
|
|||||||
|
|
||||||
return self.running_bb
|
return self.running_bb
|
||||||
|
|
||||||
|
def allocate_batch_spec_dec(self, batch: BatchBucket, n: int):
|
||||||
|
assert batch.use_spec_dec
|
||||||
|
if n > 0:
|
||||||
|
self.cache_manager.allocate_n_tokens_from_block_tables(
|
||||||
|
batch.block_tables, batch.seq_lengths, batch.current_batch_size, n=n
|
||||||
|
)
|
||||||
|
|
||||||
def add_sequence(self, req: Sequence):
|
def add_sequence(self, req: Sequence):
|
||||||
"""
|
"""
|
||||||
Add the request to waiting list.
|
Add the request to waiting list.
|
||||||
@ -282,13 +305,21 @@ class RequestHandler:
|
|||||||
|
|
||||||
return sample_tokens
|
return sample_tokens
|
||||||
|
|
||||||
def mark_finished(self, sequence: Sequence, generation_config: GenerationConfig):
|
def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
|
||||||
if (
|
if (
|
||||||
sequence.output_token_id[-1] == generation_config.eos_id
|
sequence.output_token_id[-1] == generation_config.eos_token_id
|
||||||
or sequence.output_len >= generation_config.max_output_len
|
or sequence.output_len >= generation_config.max_length
|
||||||
):
|
):
|
||||||
sequence.mark_finished()
|
sequence.mark_finished()
|
||||||
|
|
||||||
|
def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
|
||||||
|
for seq in batch.seqs_li:
|
||||||
|
if (
|
||||||
|
seq.output_token_id[-1] == generation_config.eos_token_id
|
||||||
|
or seq.output_len >= generation_config.max_length
|
||||||
|
):
|
||||||
|
seq.mark_finished()
|
||||||
|
|
||||||
def check_unfinished_seqs(self) -> bool:
|
def check_unfinished_seqs(self) -> bool:
|
||||||
return self._has_waiting() or not self.running_list.is_empty()
|
return self._has_waiting() or not self.running_list.is_empty()
|
||||||
|
|
||||||
@ -309,9 +340,20 @@ class RequestHandler:
|
|||||||
|
|
||||||
# sample the next tokens
|
# sample the next tokens
|
||||||
sample_tokens = self._sample(probs, logprobs, generation_config)
|
sample_tokens = self._sample(probs, logprobs, generation_config)
|
||||||
|
return sample_tokens
|
||||||
|
|
||||||
|
def append_next_tokens(self, sample_tokens: torch.Tensor):
|
||||||
|
assert sample_tokens.dim() == 1
|
||||||
|
n_elements = sample_tokens.size(0)
|
||||||
if not self.prefill_bb.is_empty:
|
if not self.prefill_bb.is_empty:
|
||||||
|
assert (
|
||||||
|
self.prefill_bb.current_batch_size == n_elements
|
||||||
|
), f"Incompatible size: {n_elements} tokens to append while prefill batch size {self.prefill_bb.current_batch_size}"
|
||||||
self.prefill_bb.append_batch_tokens(sample_tokens)
|
self.prefill_bb.append_batch_tokens(sample_tokens)
|
||||||
else:
|
else:
|
||||||
|
assert (
|
||||||
|
self.running_bb.current_batch_size == n_elements
|
||||||
|
), f"Incompatible size: {n_elements} tokens to append while running batch size {self.running_bb.current_batch_size}"
|
||||||
self.running_bb.append_batch_tokens(sample_tokens)
|
self.running_bb.append_batch_tokens(sample_tokens)
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
@ -349,6 +349,26 @@ class KVCacheManager:
|
|||||||
|
|
||||||
return seqs_to_recycle
|
return seqs_to_recycle
|
||||||
|
|
||||||
|
def allocate_n_tokens_from_block_tables(
|
||||||
|
self,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
context_lens: torch.Tensor,
|
||||||
|
bsz: int,
|
||||||
|
n: int,
|
||||||
|
) -> List[int]:
|
||||||
|
"""Allocate logical cache blocks for `n` new tokens for a batch of sequences during decoding stage."""
|
||||||
|
assert block_tables.dim() == 2
|
||||||
|
assert context_lens.dim() == 1
|
||||||
|
|
||||||
|
bsz = block_tables.size(0) if bsz is None else bsz
|
||||||
|
assert bsz == 1, "Support bsz 1 for now" # TODO support bsz > 1
|
||||||
|
|
||||||
|
seqs_to_recycle = []
|
||||||
|
for i in range(n):
|
||||||
|
seqs_to_recycle += self.allocate_tokens_from_block_tables(block_tables, context_lens - n + i + 1, bsz)
|
||||||
|
|
||||||
|
return seqs_to_recycle
|
||||||
|
|
||||||
def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int:
|
def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int:
|
||||||
"""Allocate space asked on a single block in the block table, specified by the provided position id,
|
"""Allocate space asked on a single block in the block table, specified by the provided position id,
|
||||||
and updates the provided block table with the allocated block.
|
and updates the provided block table with the allocated block.
|
||||||
@ -420,9 +440,7 @@ class KVCacheManager:
|
|||||||
Returns:
|
Returns:
|
||||||
The remaining space required to be allocated (in other blocks).
|
The remaining space required to be allocated (in other blocks).
|
||||||
"""
|
"""
|
||||||
assert (
|
assert block.available_space > 0, f"Found no available space left in the chosen block {block}."
|
||||||
block.available_space > 0
|
|
||||||
), "Tried to allocate some space but found no available space left in chosen block."
|
|
||||||
space_to_allocate = min(block.available_space, space_asked)
|
space_to_allocate = min(block.available_space, space_asked)
|
||||||
block.allocate(space_to_allocate)
|
block.allocate(space_to_allocate)
|
||||||
return space_asked - space_to_allocate
|
return space_asked - space_to_allocate
|
||||||
|
475
colossalai/inference/modeling/models/glide_llama.py
Normal file
475
colossalai/inference/modeling/models/glide_llama.py
Normal file
@ -0,0 +1,475 @@
|
|||||||
|
# This is modified from huggingface transformers
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/models/llama/modeling_llama.py
|
||||||
|
import warnings
|
||||||
|
from types import MethodType
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers.cache_utils import Cache, DynamicCache
|
||||||
|
from transformers.modeling_attn_mask_utils import (
|
||||||
|
_prepare_4d_causal_attention_mask,
|
||||||
|
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||||
|
)
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaAttention,
|
||||||
|
LlamaConfig,
|
||||||
|
LlamaDecoderLayer,
|
||||||
|
LlamaDynamicNTKScalingRotaryEmbedding,
|
||||||
|
LlamaForCausalLM,
|
||||||
|
LlamaLinearScalingRotaryEmbedding,
|
||||||
|
LlamaMLP,
|
||||||
|
LlamaModel,
|
||||||
|
LlamaRMSNorm,
|
||||||
|
LlamaRotaryEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
from colossalai.inference.spec import GlideInput
|
||||||
|
from colossalai.kernel.triton import flash_decoding_attention
|
||||||
|
from colossalai.logging import get_dist_logger
|
||||||
|
|
||||||
|
logger = get_dist_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
"""Rotates half the hidden dims of the input."""
|
||||||
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
|
x2 = x[..., x.shape[-1] // 2 :]
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_single_rotary_pos_emb(q, cos, sin, position_ids):
|
||||||
|
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||||
|
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||||
|
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||||
|
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||||
|
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
return q_embed
|
||||||
|
|
||||||
|
|
||||||
|
def glide_llama_causal_lm_forward(
|
||||||
|
self: LlamaForCausalLM,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
glide_input: Optional[GlideInput] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||||
|
|
||||||
|
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
glide_input=glide_input,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
logits = logits.float()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=None,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def glide_llama_model_forward(
|
||||||
|
self: LlamaModel,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
glide_input: GlideInput = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape[:2]
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
past_key_values_length = 0
|
||||||
|
if use_cache:
|
||||||
|
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||||
|
if use_legacy_cache:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if self._use_flash_attention_2:
|
||||||
|
# 2d mask is passed through the layers
|
||||||
|
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||||
|
elif self._use_sdpa and not output_attentions:
|
||||||
|
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||||
|
# the manual implementation that requires a 4D causal mask in all cases.
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 4d mask is passed through the layers
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||||
|
)
|
||||||
|
|
||||||
|
# embed positions
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
for decoder_layer in self.layers:
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
# GlideLlamaDecoderLayer
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
glide_input=glide_input,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GlideLlamaConfig(LlamaConfig):
|
||||||
|
"""Configuration class with specific arguments used by GLIDE llama model as a drafter"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
large_hidden_size=4096,
|
||||||
|
large_num_attention_heads=32,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.large_hidden_size = large_hidden_size
|
||||||
|
self.large_num_attention_heads = large_num_attention_heads
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaCrossAttention(nn.Module):
|
||||||
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
|
def __init__(self, config: GlideLlamaConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
|
self.num_key_value_heads = config.num_key_value_heads
|
||||||
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
|
|
||||||
|
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||||
|
f" and `num_heads`: {self.num_heads})."
|
||||||
|
)
|
||||||
|
|
||||||
|
# large model (verifier) configs
|
||||||
|
self.large_hidden_size = config.large_hidden_size
|
||||||
|
self.large_num_heads = config.large_num_attention_heads
|
||||||
|
self.large_head_dim = self.large_hidden_size // self.large_num_heads
|
||||||
|
|
||||||
|
self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False)
|
||||||
|
self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False)
|
||||||
|
self._init_rope()
|
||||||
|
|
||||||
|
def _init_rope(self):
|
||||||
|
if self.config.rope_scaling is None:
|
||||||
|
self.rotary_emb = LlamaRotaryEmbedding(
|
||||||
|
self.large_head_dim,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
scaling_type = self.config.rope_scaling["type"]
|
||||||
|
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.large_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.large_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        glide_input: GlideInput = None,  # Used for glimpsing main model's KV caches
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Optional[torch.Tensor]:
        bsz, q_len, _ = hidden_states.size()

        block_tables = glide_input.block_tables
        large_k_cache = glide_input.large_k_cache
        large_v_cache = glide_input.large_v_cache
        sequence_lengths = glide_input.sequence_lengths
        cache_block_size = large_k_cache.size(-2)

        query_states = self.q_proj(hidden_states)
        kv_seq_len = sequence_lengths.max().item()

        query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)

        # for RoPE
        cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len + 32)
        query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
        query_states = query_states.transpose(1, 2)
        query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)

        attn_output = flash_decoding_attention(
            q=query_states,
            k_cache=large_k_cache,
            v_cache=large_v_cache,
            kv_seq_len=sequence_lengths,
            block_tables=block_tables,
            block_size=cache_block_size,
            max_seq_len_in_batch=kv_seq_len,
        )  # attn_output: [bsz * q_len, num_heads * head_dim]

        attn_output = attn_output.reshape(bsz, q_len, self.large_hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output


# A class to be used to replace LlamaDecoderLayer in a Llama Model as Drafter in speculative decoding.
# Refer to GLIDE with a CAPE https://arxiv.org/pdf/2402.02082.pdf
class GlideLlamaDecoderLayer(nn.Module):
    def __init__(self, config: GlideLlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
        self.cross_attn = LlamaCrossAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @staticmethod
    def from_native_module(module: LlamaDecoderLayer, *args, **kwargs) -> "GlideLlamaDecoderLayer":
        """Build a GlideLlamaDecoderLayer from a native LlamaDecoderLayer"""
        config: LlamaConfig = module.mlp.config  # XXX
        layer_idx = module.self_attn.layer_idx
        glide_config = GlideLlamaConfig(**config.to_dict())
        glide_decoder_layer = GlideLlamaDecoderLayer(glide_config, layer_idx=layer_idx)

        return glide_decoder_layer

    def forward(
        self,
        hidden_states: torch.Tensor,
        glide_input: GlideInput = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        curr_q_len = hidden_states.size(1)
        # Cross attention
        if glide_input is None or not glide_input.glimpse_ready:
            warnings.warn(
                "Data used for glimpsing the past KV caches of the main model (verifier) is not complete. "
                "Fall back to normal decoder layer modeling (drafter). "
                "This might lead to incorrect results when using the Glide Models for speculative decoding."
            )
        elif curr_q_len == 1:
            # Notice that we skip prefill stage
            # always use the output of the main model as the inputs for the next round of speculation
            residual = hidden_states

            hidden_states = self.cross_attn(
                hidden_states=hidden_states,
                glide_input=glide_input,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                use_cache=True,
            )
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class GlideLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config: GlideLlamaConfig):
        super().__init__(config)
        self.config = config
        bound_method = MethodType(glide_llama_causal_lm_forward, self)
        setattr(self, "forward", bound_method)
        bound_method = MethodType(glide_llama_model_forward, self.model)
        model = getattr(self, "model")
        setattr(model, "forward", bound_method)
        replaced_layers = nn.ModuleList(
            [GlideLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        setattr(model, "layers", replaced_layers)
@ -18,6 +18,7 @@ from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
    context_attention_unpadded,
    copy_k_to_blocked_cache,
    decoding_fused_rotary_embedding,
    flash_decoding_attention,
    get_xine_cache,
@ -84,9 +85,9 @@ def llama_model_forward(
    """This function will replace the forward function of LlamaModel.

    Args:
        batch (BatchInfo): It stores the necessary input information for this inference.
        batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
        k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
        k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
        v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
        v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
        high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
    """
    block_tables = inputmetadata.block_tables
@ -100,8 +101,31 @@ def llama_model_forward(
    if batch_size >= 32 and kv_seq_len > 512:
        use_cuda_kernel = False

    # NOTE (yuanheng-zhao): for now, only triton kernels support the verification process
    # during speculative-decoding (`q_len > 1`)
    # We will explicitly disable `use_cuda_kernel` here when speculative-decoding is enabled
    if inputmetadata.use_spec_dec and use_cuda_kernel:
        use_cuda_kernel = False
        logger.warning("CUDA kernel is disabled for speculative-decoding.")

    hidden_states = self.embed_tokens(input_tokens_ids)
    if use_cuda_kernel:
    cu_seqlens = None

    # NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now
    if inputmetadata.use_spec_dec:
        # For speculative-decoding Prefill and Verifying Stage
        if inputmetadata.is_prompts:
            # output tensor shape is the same as normal Prefill Stage
            rotary_indexes = [torch.arange(0, length) for length in sequence_lengths]
        else:
            # the number of tokens to be verified in parallel plus the correct token in the last step
            n_tokens = inputmetadata.num_tokens_to_verify + 1
            assert n_tokens == hidden_states.size(0)
            rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
        rotary_indexes = torch.cat(rotary_indexes, dim=-1)
        cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])

    elif use_cuda_kernel:
        if inputmetadata != torch.float32 and use_flash_attn2:
            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
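The `rotary_indexes` arithmetic above is easiest to see with concrete numbers. Below is a small, self-contained sketch (illustrative only, not part of the patch) showing which rotary-embedding positions are gathered for a single sequence during the verification step, assuming a current length of 10 and `num_tokens_to_verify = 2` (so `n_tokens = 3`):

```python
import torch

sequence_lengths = torch.tensor([10])  # current length of the only sequence in the batch
n_tokens = 3  # num_tokens_to_verify + 1

# mirrors the list comprehension used in llama_model_forward
rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
print(torch.cat(rotary_indexes, dim=-1))  # tensor([7, 8, 9]) -> positions of the last 3 tokens
```

So the last `n_tokens` positions of each sequence receive cos/sin values, matching the tokens being verified in parallel.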
@ -113,14 +137,13 @@ def llama_model_forward(
            self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts
        )
        cos_sin = (cos, sin)

    else:
        cu_seqlens = None
        cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)

    sm_scale = 1.0 / (inputmetadata.head_dim**0.5)

    norm_output = torch.empty_like(hidden_states)
    tokens_to_verify = inputmetadata.num_tokens_to_verify if inputmetadata.use_spec_dec else None
    residual = None

    for layer_id, decoder_layer in enumerate(self.layers):
@ -131,6 +154,8 @@ def llama_model_forward(
            k_cache=k_caches[layer_id],
            v_cache=v_caches[layer_id],
            is_prompts=inputmetadata.is_prompts,
            is_verifier=inputmetadata.use_spec_dec,
            tokens_to_verify=tokens_to_verify,
            sequence_lengths=sequence_lengths,
            cos_sin=cos_sin,
            fd_inter_tensor=inputmetadata.fd_inter_tensor,
@ -144,9 +169,9 @@ def llama_model_forward(
        )

    if inputmetadata.is_prompts:
        last_token_indexs = sequence_lengths.cumsum(dim=-1)
        seq_len_cumsum = sequence_lengths.cumsum(dim=0)
        hidden_states = hidden_states[last_token_indexs - 1].contiguous()
        hidden_states = hidden_states[seq_len_cumsum - 1].contiguous()
        residual = residual[last_token_indexs - 1].contiguous()
        residual = residual[seq_len_cumsum - 1].contiguous()
        norm_output = torch.empty_like(hidden_states)
        hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)

@ -164,6 +189,8 @@ def llama_decoder_layer_forward(
    cos_sin: Tuple[torch.Tensor],
    fd_inter_tensor: FDIntermTensors,
    is_prompts: bool = True,
    is_verifier: bool = False,
    tokens_to_verify: int = None,
    kv_seq_len: int = 0,
    output_tensor: torch.Tensor = None,
    norm_output: torch.Tensor = None,
@ -202,10 +229,12 @@ def llama_decoder_layer_forward(
        block_tables=block_tables,
        k_cache=k_cache,
        v_cache=v_cache,
        is_prompts=is_prompts,
        is_verifier=is_verifier,
        tokens_to_verify=tokens_to_verify,
        sequence_lengths=sequence_lengths,
        cos_sin=cos_sin,
        fd_inter_tensor=fd_inter_tensor,
        is_prompts=is_prompts,
        kv_seq_len=kv_seq_len,
        output_tensor=output_tensor,
        sm_scale=sm_scale,
@ -312,6 +341,8 @@ class NopadLlamaAttention(LlamaAttention):
        cos_sin: Tuple[torch.Tensor],
        fd_inter_tensor: FDIntermTensors,
        is_prompts: bool = True,
        is_verifier: bool = False,
        tokens_to_verify: int = None,
        kv_seq_len: int = 0,
        output_tensor: torch.Tensor = None,
        sm_scale: int = None,
@ -355,7 +386,7 @@ class NopadLlamaAttention(LlamaAttention):
        block_size = k_cache.size(-2)

        if is_prompts:
            if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                # flash attn 2 currently only supports FP16/BF16.
                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                inference_ops.context_kv_cache_memcpy(
@ -391,6 +422,8 @@ class NopadLlamaAttention(LlamaAttention):
                    sm_scale=sm_scale,
                )
        else:
            q_len = tokens_to_verify + 1 if is_verifier else 1

            if use_cuda_kernel:
                inference_ops.rotary_embedding_and_cache_copy(
                    query_states,
@ -405,17 +438,26 @@ class NopadLlamaAttention(LlamaAttention):
                    high_precision,
                )
            else:
                decoding_fused_rotary_embedding(
                    query_states,
                    key_states,
                    value_states,
                    cos_sin[0],
                    cos_sin[1],
                    k_cache,
                    v_cache,
                    block_tables,
                    sequence_lengths,
                )
                if is_verifier:
                    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                    copy_k_to_blocked_cache(
                        key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )
                    copy_k_to_blocked_cache(
                        value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )
                else:
                    decoding_fused_rotary_embedding(
                        query_states,
                        key_states,
                        value_states,
                        cos_sin[0],
                        cos_sin[1],
                        k_cache,
                        v_cache,
                        block_tables,
                        sequence_lengths,
                    )
            attn_output = flash_decoding_attention(
                q=query_states,
                k_cache=k_cache,
@ -428,8 +470,10 @@ class NopadLlamaAttention(LlamaAttention):
                mid_output=fd_inter_tensor.mid_output,
                mid_output_lse=fd_inter_tensor.mid_output_lse,
                sm_scale=sm_scale,
                q_len=q_len,
            )

        attn_output = attn_output.view(-1, self.hidden_size)
        attn_output = torch.mm(attn_output, self.o_proj_weight)

        return attn_output
@ -1,7 +1,9 @@
from .glide_llama import GlideLlamaModelPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy

model_policy_map = {
    "nopadding_llama": NoPaddingLlamaModelInferPolicy,
    "glide_llama": GlideLlamaModelPolicy,
}

__all__ = ["NoPaddingLlamaModelInferPolicy", "model_polic_map"]
__all__ = ["NoPaddingLlamaModelInferPolicy", "GlideLlamaModelPolicy", "model_polic_map"]
45 colossalai/inference/modeling/policy/glide_llama.py Normal file
@ -0,0 +1,45 @@
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel

from colossalai.inference.modeling.models.glide_llama import (
    GlideLlamaDecoderLayer,
    glide_llama_causal_lm_forward,
    glide_llama_model_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy


class GlideLlamaModelPolicy(LlamaForCausalLMPolicy):
    def module_policy(self):
        policy = super().module_policy()

        num_layers = self.model.config.num_hidden_layers
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix=f"layers[{i}]",
                    target_module=GlideLlamaDecoderLayer,
                )
                for i in range(num_layers)
            ],
            policy=policy,
            target_key=LlamaModel,
        )
        self.append_or_create_method_replacement(
            description={"forward": glide_llama_model_forward},
            policy=policy,
            target_key=LlamaModel,
        )
        self.append_or_create_method_replacement(
            description={"forward": glide_llama_causal_lm_forward},
            policy=policy,
            target_key=LlamaForCausalLM,
        )

        return policy

    def postprocess(self):
        for layer in self.model.model.layers:
            init_to_get_rotary(layer.cross_attn)
        return self.model
96 colossalai/inference/spec/README.md Normal file
@ -0,0 +1,96 @@
# Speculative Decoding

Colossal-Inference supports speculative decoding using the inference engine, with optimized kernels and cache management for the main model.

Both a drafter model (small model) and a main model (large model) are used during speculative decoding. The drafter model generates a few tokens sequentially, and the main model then validates those candidate tokens in parallel and accepts the validated ones. Decoding is sped up because speculating multiple tokens with the drafter model has lower latency than generating them one by one with the main model.
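In essence, verification keeps the longest prefix of drafted tokens that the main model would itself have produced, plus one extra token from the main model. The snippet below is a minimal sketch of this greedy acceptance rule (illustrative only; the engine's real implementation also updates KV caches and batch state):

```python
import torch

def accept_drafted_tokens(drafted_ids: torch.Tensor, main_model_ids: torch.Tensor) -> torch.Tensor:
    """Greedy verification: keep the longest prefix of drafted tokens the main model agrees with.

    drafted_ids:    [n_spec] token ids proposed by the drafter
    main_model_ids: [n_spec + 1] token ids the main model emits at those positions (plus one bonus token)
    """
    n_matches = 0
    for drafted, verified in zip(drafted_ids.tolist(), main_model_ids.tolist()):
        if drafted != verified:
            break
        n_matches += 1
    # accepted drafted tokens plus the main model's own next token
    return main_model_ids[: n_matches + 1]
```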
Moreover, Colossal-Inference also supports GLIDE, a modified draft-model architecture that reuses the key and value caches of the main model, which improves the acceptance rate and increases the speed-up ratio. Details can be found in the research paper GLIDE with a CAPE - A Low-Hassle Method to Accelerate Speculative Decoding on [arXiv](https://arxiv.org/pdf/2402.02082.pdf).

Right now, Colossal-Inference offers a GLIDE model compatible with vicuna7B. You can find the fine-tuned GLIDE drafter model `cxdu/glide47m-vicuna7b` on the HuggingFace Hub: https://huggingface.co/cxdu/glide47m-vicuna7b.

## Usage

For the main model, you might want to use the model card `lmsys/vicuna-7b-v1.5` at [HuggingFace Hub](https://huggingface.co/lmsys/vicuna-7b-v1.5).
For a regular drafter model, you might want to use the model card `JackFram/llama-68m` at [HuggingFace Hub](https://huggingface.co/JackFram/llama-68m).
For the GLIDE drafter model, you could use the model card `cxdu/glide47m-vicuna7b` at [HuggingFace Hub](https://huggingface.co/cxdu/glide47m-vicuna7b).

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine, GenerationConfig
from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM, GlideLlamaConfig

# launch colossalai, setup distributed environment
colossalai.launch_from_torch(config={})

# main model
model_path_or_name = "REPLACE_TO_VICUNA_7B_PATH_OR_MODEL_CARD"
model = AutoModelForCausalLM.from_pretrained(model_path_or_name)

# use the same tokenizer for both the main model and the drafter model
tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
tokenizer.pad_token = tokenizer.eos_token

# drafter model
drafter_model_path_or_name = "REPLACE_TO_LLAMA_68M_PATH_OR_MODEL_CARD"
drafter_model = AutoModelForCausalLM.from_pretrained(drafter_model_path_or_name)

# Initialize the inference engine
inference_config = InferenceConfig(
    dtype="fp16",
    max_batch_size=1,
    max_input_len=256,
    max_output_len=256,
    prefill_ratio=1.2,
    block_size=16,
    max_n_spec_tokens=5,
    prompt_template="vicuna",
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# turn on speculative decoding with the drafter model
engine.enable_spec_dec(drafter_model)

prompt = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions. "
generation_config = GenerationConfig(
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_length=128,
    num_beams=1,
    do_sample=False,
)
out = engine.generate(prompts=[prompt], generation_config=generation_config)
print(out)

# use GLIDE Llama model as drafter model
drafter_model_path_or_name = "cxdu/glide47m-vicuna7b"
glide_config = GlideLlamaConfig(
    intermediate_size=8192,
    large_hidden_size=4096,
    large_num_attention_heads=32,
    num_hidden_layers=1,
)
drafter_model = GlideLlamaForCausalLM.from_pretrained(drafter_model_path_or_name, config=glide_config)

# turn on speculative decoding with the GLIDE model
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
out = engine.generate(prompts=[prompt], generation_config=generation_config)
print(out)
```

You could run the above code with

```bash
colossalai run --nproc_per_node 1 script_name.py
```

## Benchmark

With batch size 1, testing with the gsm8k and MT-Bench datasets on an NVIDIA H800 80G:

| Method | Tokens/Sec |
| :--------------------------- | :--------- |
| Non-Spec-Dec | ~90 |
| Spec-Dec | ~115 |
| Spec-Dec with GLIDE Model | ~135 |
4 colossalai/inference/spec/__init__.py Normal file
@ -0,0 +1,4 @@
from .drafter import Drafter
from .struct import DrafterOutput, GlideInput

__all__ = ["Drafter", "DrafterOutput", "GlideInput"]
121 colossalai/inference/spec/drafter.py Normal file
@ -0,0 +1,121 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer

from colossalai.utils import get_current_device

from .struct import DrafterOutput, GlideInput


class Drafter:
    """Container for the Drafter Model (Assistant Model) used in Speculative Decoding.

    Args:
        model (nn.Module): The drafter model.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
        device (torch.device): The device for the drafter model.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: PreTrainedTokenizer,
        device: torch.device = None,
        dtype: torch.dtype = torch.float16,
    ):
        self._tokenizer = tokenizer
        self._device = device or get_current_device()
        self._dtype = dtype
        self._drafter_model = model.to(self._device)
        self._drafter_model = model.to(self._dtype)
        self._drafter_model.eval()

    def get_model(self) -> nn.Module:
        return self._drafter_model

    @staticmethod
    def trim_kv_cache(
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int
    ) -> Tuple[Tuple[torch.FloatTensor]]:
        """Trim the last `invalid_token_num` kv caches.

        past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape
            num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
        invalid_token_num (int): The number of invalid tokens to trim.
        """
        if past_key_values is None or invalid_token_num < 1:
            return past_key_values

        trimmed_past_key_values = []
        for layer_idx in range(len(past_key_values)):
            past_key_value = past_key_values[layer_idx]
            trimmed_past_key_values.append(
                (
                    past_key_value[0][:, :, :-invalid_token_num, :],
                    past_key_value[1][:, :, :-invalid_token_num, :],
                )
            )
        past_key_values = tuple(trimmed_past_key_values)
        return past_key_values

    @torch.inference_mode()
    def speculate(
        self,
        input_ids: torch.Tensor,
        n_spec_tokens: int,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        glide_input: Optional[GlideInput] = None,
    ) -> DrafterOutput:
        """Generate n_spec_tokens tokens using the drafter model.

        Args:
            input_ids (torch.Tensor): Input token ids.
            n_spec_tokens (int): Number of tokens to speculate.
            past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
            glide_input (Optional[GlideInput]): The packed input for glimpsing kv caches of the main model,
                when using the glide model as a drafter.
        """
        assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate"

        # For compatibility with transformers of versions before 4.38.0
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        logits = []
        token_ids = []

        kwargs = {"return_dict": True, "use_cache": True}
        if glide_input:
            # required only when using glide model
            kwargs["glide_input"] = glide_input

        for _ in range(n_spec_tokens):
            # update past key values
            kwargs["past_key_values"] = past_key_values

            outputs = self._drafter_model(input_ids, **kwargs)
            next_token_logits = outputs.logits[:, -1, :]

            # NOTE Only use greedy search for speculating.
            # As the drafter model usually has only a few layers with few parameters,
            # introducing sampling will make the speculation unstable and lead to worse performance.
            next_token_ids = torch.argmax(next_token_logits, dim=-1)

            logits.append(next_token_logits)
            token_ids.append(next_token_ids)
            if next_token_ids.item() == self._tokenizer.eos_token_id:
                # TODO(yuanheng-zhao) support bsz > 1
                break
            input_ids = next_token_ids[:, None]
            past_key_values = outputs.past_key_values

        speculated_length = len(token_ids)  # For now, only support bsz 1
        logits = torch.concat(logits, dim=0)
        token_ids = torch.concat(token_ids, dim=-1)

        out = DrafterOutput(
            speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
        )
        return out
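As a quick orientation, the snippet below sketches how `Drafter` is typically driven. It is a minimal, illustrative example using a toy Llama drafter on a CUDA device, mirroring the unit test further down; the inference engine wires this up automatically when `enable_spec_dec` is called.

```python
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

from colossalai.inference.spec import Drafter

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
drafter_model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=1)).eval().cuda()
drafter = Drafter(drafter_model, tokenizer, dtype=torch.float16)

input_ids = torch.randint(low=5, high=1000, size=(1, 6)).cuda()
out = drafter.speculate(input_ids, n_spec_tokens=5)  # greedy speculation of 5 tokens

# after the main model rejects, say, the last 2 drafted tokens,
# drop the corresponding entries from the drafter's KV cache
trimmed = Drafter.trim_kv_cache(out.past_key_values, invalid_token_num=2)
```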
55 colossalai/inference/spec/struct.py Normal file
@ -0,0 +1,55 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class DrafterOutput:
    """
    Dataclass for drafter model outputs.

    Args:
        speculated_length (int): Speculated length of the output sequence
            It is always less than or equal to spec_num during drafter's speculation process
        logits (torch.FloatTensor): Logits of the output sequence
        next_tokens (torch.Tensor): Next token ids
        past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]]): Past key values of the output sequence
    """

    speculated_length: int = None
    logits: torch.FloatTensor = None
    next_tokens: torch.Tensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

    def __post_init__(self):
        assert self.speculated_length is not None and self.speculated_length >= 0
        if self.past_key_values is not None:
            assert isinstance(self.past_key_values, tuple), "Past key values should be a tuple"
            assert all([isinstance(past_key_value, tuple) for past_key_value in self.past_key_values])


@dataclass
class GlideInput:
    """Dataclass for Glide Models (e.g. `colossalai/inference/modeling/models/glide_llama.py`).
    Used to pack data that will be used during glimpsing KV Caches of the main model.

    Args:
        block_tables (torch.Tensor): [num_seqs, max_blocks_per_seq] The block table of KV Caches.
        large_k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_size]
            Blocked key cache of the main model
        large_v_cache (torch.Tensor): Blocked value cache of the main model. It has the same shape as k cache.
        sequence_lengths (torch.Tensor): [num_seqs] Sequence lengths of the current batch.
    """

    block_tables: torch.Tensor = None
    large_k_cache: torch.Tensor = None
    large_v_cache: torch.Tensor = None
    sequence_lengths: torch.Tensor = None

    @property
    def glimpse_ready(self):
        return all(
            attr is not None
            for attr in [self.block_tables, self.large_k_cache, self.large_v_cache, self.sequence_lengths]
        )
@ -11,7 +11,7 @@ if HAS_TRITON:
    from .context_attn_unpad import context_attention_unpadded
    from .flash_decoding import flash_decoding_attention
    from .fused_rotary_embedding import fused_rotary_embedding
    from .kvcache_copy import copy_kv_to_blocked_cache
    from .kvcache_copy import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
    from .no_pad_rotary_embedding import decoding_fused_rotary_embedding, rotary_embedding
    from .rms_layernorm import rms_layernorm
    from .rotary_cache_copy import get_xine_cache
@ -20,6 +20,7 @@ if HAS_TRITON:
    __all__ = [
        "context_attention_unpadded",
        "flash_decoding_attention",
        "copy_k_to_blocked_cache",
        "copy_kv_to_blocked_cache",
        "softmax",
        "rms_layernorm",
@ -9,13 +9,14 @@ import triton.language as tl
# Triton 2.1.0
@triton.jit
def _flash_decoding_fwd_kernel(
    Q,  # [batch_size, head_num, q_len(1), head_dim]
    Q,  # [batch_size * q_len, head_num, head_dim]
    KCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
    VCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
    block_tables,  # [batch_size, max_blocks_per_sequence]
    mid_o,  # [batch_size, head_num, kv_split_num, head_dim]
    mid_o,  # [batch_size * q_len, head_num, kv_split_num, head_dim]
    mid_o_lse,  # [batch_size, head_num, kv_split_num]
    mid_o_lse,  # [batch_size * q_len, head_num, kv_split_num]
    kv_seq_len,  # [batch_size]
    q_len,
    batch_size,
    stride_qt,
    stride_qh,
@ -39,44 +40,39 @@ def _flash_decoding_fwd_kernel(
    BLOCK_SIZE: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    cur_seq_idx = tl.program_id(0)
    cur_token_idx = tl.program_id(0)
    cur_seq_idx = cur_token_idx // q_len
    if cur_seq_idx >= batch_size:
        return
    cur_token_off = (cur_token_idx % q_len) - q_len + 1
    cur_head_idx = tl.program_id(1)
    block_start_kv = tl.program_id(2)  # for splitting k/v

    cur_kv_head_idx = cur_head_idx // KV_GROUPS
    offsets_dmodel = tl.arange(0, HEAD_DIM)

    # NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same
    # TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE)
    # and then support calculating multiple kv cache blocks on an instance
    tl.static_assert(BLOCK_KV == BLOCK_SIZE)
    # get the current (kv) sequence length
    # get the current (kv) sequence length from provided context lengths tensor
    # cur_token_off is used as a "mask" here for spec-dec during verification process
    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off

    offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
    q = tl.load(Q + offsets_q)

    # block table for the current sequence
    block_table_ptr = block_tables + cur_seq_idx * stride_bts

    # actually current block table current block start idx
    # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
    cur_bt_start_idx = block_start_kv
    cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)

    if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
        return

    offsets_dmodel = tl.arange(0, HEAD_DIM)
    offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
    q = tl.load(Q + offsets_q)
    # block table for the current sequence
    block_table_ptr = block_tables + cur_seq_idx * stride_bts
    # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
    # cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
    cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb)
    cur_occupied_size = tl.where(
        (block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE
    )
    tl.device_assert(cur_occupied_size >= 0)

    cur_kv_head_idx = cur_head_idx // KV_GROUPS
    offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh

    K_block_ptr = tl.make_block_ptr(
        base=KCache + offset_kvcache,
        shape=(cur_occupied_size, HEAD_DIM),
@ -115,14 +111,14 @@ def _flash_decoding_fwd_kernel(
    acc = acc / l

    offsets_mid_o = (
        cur_seq_idx * stride_mid_ot
        cur_token_idx * stride_mid_ot
        + cur_head_idx * stride_mid_oh
        + block_start_kv * stride_mid_ob
        + offsets_dmodel * stride_mid_od
    )
    tl.store(mid_o + offsets_mid_o, acc)
    offsets_mid_o_lse = (
        cur_seq_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
        cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
    )
    # logsumexp L^(j) = m^(j) + log(l^(j))
    tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
@ -135,6 +131,7 @@ def _flash_decoding_fwd_reduce_kernel(
    mid_o_lse,  # [batch_size, head_num, kv_split_num]
    O,  # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
    kv_seq_len,
    q_len,
    batch_size,
    stride_mid_ot,
    stride_mid_oh,
@ -149,12 +146,15 @@ def _flash_decoding_fwd_reduce_kernel(
    BLOCK_KV: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    cur_seq_idx = tl.program_id(0)
    cur_token_idx = tl.program_id(0)
    cur_seq_idx = cur_token_idx // q_len
    if cur_seq_idx >= batch_size:
        return
    cur_head_idx = tl.program_id(1)

    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
    # cur_token_off is used as a "mask" here for spec-dec during verification process
    cur_token_off = (cur_token_idx % q_len) - q_len + 1
    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
    offsets_dmodel = tl.arange(0, HEAD_DIM)

    # NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have
@ -164,8 +164,8 @@ def _flash_decoding_fwd_reduce_kernel(
    l = 0.0  # sum exp
    acc = tl.zeros([HEAD_DIM], dtype=tl.float32)

    offsets_mid_o = cur_seq_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
    offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
    offset_mid_lse = cur_seq_idx * stride_o_lset + cur_head_idx * stride_o_lseh
    offset_mid_lse = cur_token_idx * stride_o_lset + cur_head_idx * stride_o_lseh
    for block_i in range(0, kv_split_num, 1):
        mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob)
        lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb)
@ -179,7 +179,7 @@ def _flash_decoding_fwd_reduce_kernel(
        m_i = m_ij

    acc = acc / l
    offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
    offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
    tl.store(O + offsets_O, acc.to(O.type.element_ty))
    return

@ -199,12 +199,14 @@ def flash_decoding_attention(
    mid_output_lse: torch.Tensor = None,
    sm_scale: int = None,
    kv_group_num: int = 1,
    q_len: int = 1,
):
    """
    Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.

    Args:
        q (torch.Tensor): [bsz, num_heads, head_dim]
        q (torch.Tensor): [bsz * q_len, num_heads, head_dim]
            q_len > 1 only for verification process in speculative-decoding.
        k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
        v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
        kv_seq_len (torch.Tensor): [batch_size]
@ -212,19 +214,25 @@ def flash_decoding_attention(
        block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
        max_seq_len_in_batch (int): Maximum sequence length in the batch.
        output (torch.Tensor): [bsz, num_heads * head_dim]
        mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
        mid_output (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num, head_dim]
            Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
            q_len > 1 only for verification process in speculative-decoding.
        mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
        mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num]
            Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
            q_len > 1 only for verification process in speculative-decoding.
        block_size (int): Size of each block in the blocked key/value cache.
        num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
        q_len (int): Query length. Used for speculative decoding when `q_len` > 1 (i.e. the last n tokens).
            Defaults to 1.

    Returns:
        Output tensor with shape [bsz, num_heads * head_dim]
        Output tensor with shape [bsz * q_len, num_heads * head_dim]
    """
    q = q.squeeze() if q.dim() == 4 else q
    assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
    bsz, num_heads, head_dim = q.shape
    n_tokens, num_heads, head_dim = q.shape
    assert n_tokens % q_len == 0, "Invalid q_len"
    bsz = n_tokens // q_len

    assert head_dim in {32, 64, 128, 256}
    assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
@ -247,22 +255,31 @@ def flash_decoding_attention(
    max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
    # For compatibility (TODO revise modeling in future)
    kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
    mid_output = (
        torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
        if mid_output is None
        else mid_output
    )
    mid_output_lse = (
        torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
        if mid_output_lse is None
        else mid_output_lse
    )
    if mid_output is None:
        mid_output = torch.empty(
            (bsz * q_len, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
        )
    if mid_output_lse is None:
        mid_output_lse = torch.empty((bsz * q_len, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
    if output is None:
        # A hack to prevent `view` operation in modeling
        output = torch.empty((bsz * q_len, num_heads * head_dim), dtype=q.dtype, device=q.device)

    assert (
        mid_output.size(2) == mid_output_lse.size(2) >= kv_max_split_num
    ), "Incompatible kv split number of intermediate output tensors"
    assert (
        mid_output.size(0) == mid_output_lse.size(0) >= output.size(0) == n_tokens
    ), f"Incompatible first dimension of output tensors"

    # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
    # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
    grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
    output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output
    grid = (
        triton.next_power_of_2(bsz * q_len),
        num_heads,
        triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
    )
    _flash_decoding_fwd_kernel[grid](
        q,
        k_cache,
@ -271,6 +288,7 @@ def flash_decoding_attention(
        mid_output,
        mid_output_lse,
        kv_seq_len,
        q_len,
        bsz,
        q.stride(0),
        q.stride(1),
@ -295,13 +313,13 @@ def flash_decoding_attention(
        HEAD_DIM=head_dim,
    )

    grid = (triton.next_power_of_2(bsz), num_heads)
    grid = (triton.next_power_of_2(bsz * q_len), num_heads)

    _flash_decoding_fwd_reduce_kernel[grid](
        mid_output,
        mid_output_lse,
        output,
        kv_seq_len,
        q_len,
        bsz,
        mid_output.stride(0),
        mid_output.stride(1),
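To make the `q_len`/`cur_token_off` bookkeeping above concrete, here is a tiny pure-Python illustration (not part of the patch) of the effective KV length each query token sees during verification, assuming one sequence with a recorded context length of 10 and `q_len = 3`:

```python
q_len = 3
kv_seq_len = [10]  # recorded context length per sequence

for cur_token_idx in range(len(kv_seq_len) * q_len):
    cur_seq_idx = cur_token_idx // q_len
    cur_token_off = (cur_token_idx % q_len) - q_len + 1  # -2, -1, 0
    effective_len = kv_seq_len[cur_seq_idx] + cur_token_off
    print(cur_token_idx, effective_len)  # the three tokens attend to 8, 9, 10 cached positions
```

Each of the `q_len` tokens being verified therefore attends only to the causally valid prefix of the blocked KV cache.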
@ -3,6 +3,50 @@ import triton
import triton.language as tl


# Triton 2.1.0
@triton.jit
def _copy_to_kcache_seqlen_n_kernel(
    KV,  # K or V
    KVCache,  # KCache or VCache
    BLOCK_TABLES,
    context_lengths,
    stride_kt,
    stride_kh,
    stride_kd,
    stride_cacheb,
    stride_cacheh,
    stride_cachebs,
    stride_cached,
    stride_bts,
    stride_btb,
    block_size,
    n,
    HEAD_DIM: tl.constexpr,
):
    cur_token_idx = tl.program_id(0)
    cur_seq_idx = cur_token_idx // n
    cur_token_shift = cur_token_idx - (n * (cur_seq_idx + 1))
    # cur_token_shift = cur_token_idx - n * cur_seq_idx
    cur_kv_head_idx = tl.program_id(1)

    past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) + cur_token_shift
    last_bt_block_idx = past_kv_seq_len // block_size
    block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
    block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
    offset_last_block = past_kv_seq_len % block_size
    offsets_dmodel = tl.arange(0, HEAD_DIM)
    offsets_kv = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
    kv = tl.load(KV + offsets_kv)
    offsets_kvcache = (
        block_id * stride_cacheb
        + cur_kv_head_idx * stride_cacheh
        + offset_last_block * stride_cachebs
        + offsets_dmodel * stride_cached
    )
    tl.store(KVCache + offsets_kvcache, kv)
    return


# Triton 2.1.0
@triton.jit
def _copy_to_kvcache_seqlen1_kernel(
@ -40,10 +84,11 @@ def _copy_to_kvcache_seqlen1_kernel(
    block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
    offsets_in_last_block = past_kv_seq_len % block_size
    offsets_dmodel = tl.arange(0, HEAD_DIM)
    offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
    offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
    offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel * stride_vd

    k = tl.load(K + offsets_kv)
    k = tl.load(K + offsets_k)
    v = tl.load(V + offsets_kv)
    v = tl.load(V + offsets_v)

    offsets_kcache = (
        block_id * stride_cachekb
@ -63,6 +108,64 @@ def _copy_to_kvcache_seqlen1_kernel(
    return


def copy_k_to_blocked_cache(
    k: torch.Tensor, k_cache: torch.Tensor, kv_lengths: torch.Tensor, block_tables: torch.Tensor, n: int = 1
):
    """
    Copy keys or values to the blocked key/value cache during decoding stage.

    Args:
        k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
            [bsz * n, num_kv_heads, head_dim] - Keys or values with seq len n
        k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
        kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
        block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
        n (int): Number of tokens to copy for each sequence. Defaults to 1.
    """
    assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
    assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."

    k = k.reshape(-1, k.size(-2), k.size(-1)) if k.dim() == 4 else k
    assert k.dim() == 3, f"Invalid k dim {k.dim()}"
    bsz, num_kv_heads, head_dim = k.shape
    # NOTE when n > 1, the shape of k is [bsz * n, num_kv_heads, head_dim]
    if n > 1:
        assert bsz % n == 0, "Each sequence should have the same number of tokens to be copied"
        bsz = bsz // n

    assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
        f"Got incompatible batch size (number of seqs):\n"
        f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
        f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}"
    )

    # Modify if the shape of kv cache is changed.
    block_size = k_cache.size(-2)

    num_warps = 8 if head_dim > 128 else 4

    grid = (bsz * n, num_kv_heads)
    _copy_to_kcache_seqlen_n_kernel[grid](
        k,
        k_cache,
        block_tables,
        kv_lengths,
        k.stride(0),
        k.stride(1),
        k.stride(2),
        k_cache.stride(0),
        k_cache.stride(1),
        k_cache.stride(2),
        k_cache.stride(3),
        block_tables.stride(0),
        block_tables.stride(1),
        block_size,
        n=n,
        HEAD_DIM=head_dim,
        num_warps=num_warps,
    )


def copy_kv_to_blocked_cache(
    k: torch.Tensor,
    v: torch.Tensor,
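For readers who do not want to trace the Triton kernel, the following CPU-only reference (an illustrative sketch, not part of the patch) captures what `copy_k_to_blocked_cache` does for the last `n` tokens of each sequence:

```python
import torch

def copy_k_to_blocked_cache_ref(k, k_cache, kv_lengths, block_tables, n=1):
    """Reference semantics: write the last n tokens of each sequence into its blocked cache slots."""
    # k: [bsz * n, num_kv_heads, head_dim]; k_cache: [num_blocks, num_kv_heads, block_size, head_dim]
    block_size = k_cache.size(-2)
    bsz = k.size(0) // n
    for seq in range(bsz):
        for j in range(n):
            # global position of this token within the sequence (0-indexed)
            pos = int(kv_lengths[seq]) - n + j
            block_id = int(block_tables[seq, pos // block_size])
            k_cache[block_id, :, pos % block_size, :] = k[seq * n + j]
```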
71
tests/test_infer/test_drafter.py
Normal file
71
tests/test_infer/test_drafter.py
Normal file
@ -0,0 +1,71 @@
import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
from colossalai.inference.spec.drafter import Drafter
from colossalai.utils import get_current_device

NUM_LAYERS = 1
MAX_LEN = 100
SPEC_NUM = 5


@pytest.mark.parametrize("spec_num", [SPEC_NUM])
def test_drafter(spec_num: int):
    torch.manual_seed(123)

    device = get_current_device()

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
    toy_config.pad_token_id = tokenizer.eos_token_id
    drafter_model = LlamaForCausalLM(toy_config)
    drafter_model = drafter_model.eval().cuda()

    drafter = Drafter(drafter_model, tokenizer, device=device)

    input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
    out = drafter.speculate(input_ids, spec_num)
    past_kv_length = input_ids.size(1) + spec_num - 1

    assert out.speculated_length == spec_num
    assert out.next_tokens.shape == (spec_num,)
    assert out.logits.shape == (spec_num, len(tokenizer))
    assert out.past_key_values[0][0].size(2) == past_kv_length

    reject_num = max(0, spec_num - 1)
    trimmed_past_key_values = drafter.trim_kv_cache(out.past_key_values, reject_num)
    assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num


def test_spec_dec():
    spec_num = SPEC_NUM
    device = get_current_device()
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    tokenizer.pad_token = tokenizer.eos_token

    # Dummy config for Glide Model
    glide_config = GlideLlamaConfig(
        intermediate_size=8192,
        large_hidden_size=4096,
        large_num_attention_heads=32,
        num_hidden_layers=NUM_LAYERS,
    )
    drafter_model = GlideLlamaForCausalLM(glide_config)

    assert hasattr(drafter_model, "model")
    assert hasattr(drafter_model.model, "layers")
    for _, layer in enumerate(drafter_model.model.layers):
        assert hasattr(layer, "cross_attn")

    # Init the Drafter by providing the sharded drafter model
    drafter = Drafter(drafter_model, tokenizer, device=device, dtype=torch.float16)

    input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
    out = drafter.speculate(input_ids, spec_num, past_key_values=None)


if __name__ == "__main__":
    test_drafter(spec_num=SPEC_NUM)
    test_spec_dec()
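The assertions in test_drafter encode a small piece of cache-length arithmetic that is worth spelling out; the snippet below simply restates, with the test's own numbers, what those asserts check.

# input_ids carries 6 prompt tokens and the drafter proposes spec_num = 5 tokens.
# Its KV cache then covers the prompt plus all but the last drafted token:
prompt_len, spec_num = 6, 5
past_kv_length = prompt_len + spec_num - 1            # == 10, matches out.past_key_values[0][0].size(2)

# Rolling back reject_num = spec_num - 1 = 4 rejected tokens trims the cache accordingly:
reject_num = spec_num - 1
trimmed_length = past_kv_length - reject_num          # == 6, matches the trimmed cache length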
@ -9,6 +9,7 @@ import colossalai
 from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@ -80,9 +81,81 @@ def check_output_consistency(prompt_template):
     FDIntermTensors._instances = {}
 
 
+@parameterize("num_layers", [1])
+@parameterize("max_length", [100])
+def check_spec_dec(num_layers, max_length):
+    torch.manual_seed(123)
+
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    # Dummy configs for testing
+    toy_config = LlamaConfig(num_hidden_layers=num_layers)
+    toy_config.pad_token_id = tokenizer.eos_token_id
+    drafter_model = LlamaForCausalLM(toy_config)
+    drafter_model = drafter_model.eval().cuda()
+    large_config = LlamaConfig(
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_attention_heads=32,
+        num_hidden_layers=8,
+        num_key_value_heads=32,
+        max_position_embeddings=2048,
+    )
+    large_config.pad_token_id = tokenizer.eos_token_id
+    main_model = LlamaForCausalLM(large_config)
+
+    inference_config = InferenceConfig(
+        dtype="fp16",
+        micro_batch_size=1,
+        max_batch_size=1,
+        max_input_len=128,
+        max_output_len=128,
+        prefill_ratio=1.2,
+        block_size=16,
+    )
+    engine = InferenceEngine(main_model, tokenizer, inference_config)
+    engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
+
+    dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
+    generation_config = GenerationConfig(
+        pad_token_id=tokenizer.eos_token_id,
+        max_length=max_length,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+    out, out_token_ids = engine.generate(
+        prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
+    )
+    engine.disable_spec_dec()
+    engine.clear_spec_dec()
+
+    assert not engine.use_spec_dec
+    assert engine.drafter is None and engine.drafter_model is None
+
+    assert len(out) == 1
+    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
+
+    # test GLIDE model
+    glide_config = GlideLlamaConfig(
+        intermediate_size=8192,
+        large_hidden_size=4096,
+        large_num_attention_heads=32,
+        num_hidden_layers=num_layers,
+    )
+    glide_model = GlideLlamaForCausalLM(glide_config)
+    engine.enable_spec_dec(glide_model, use_glide_drafter=True)
+
+    out, out_token_ids = engine.generate(
+        prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
+    )
+    engine.clear_spec_dec()
+
+    assert len(out) == 1
+    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
     check_output_consistency()
+    check_spec_dec()
 
 
 @pytest.mark.dist
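Condensed from check_spec_dec above, the engine-level lifecycle is enable → generate → disable/clear. The sketch below is a reading aid that reuses the engine, drafter_model, dummy_inputs, and generation_config objects built in that test; it is not a standalone script.

# Attach a drafter and choose how many tokens it proposes per verification step.
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)

# Generate as usual; drafted tokens are verified by the main model in parallel.
out, out_token_ids = engine.generate(
    prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
)

# Switch speculative decoding off ...
engine.disable_spec_dec()
# ... or drop the drafter entirely: afterwards engine.use_spec_dec is False and
# engine.drafter / engine.drafter_model are None, as asserted in the test above.
engine.clear_spec_dec()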
@ -19,12 +19,19 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
 
 
-def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, kv_seq_len: int, device="cuda"):
-    padding_mask = torch.zeros((bsz, 1, 1, kv_seq_len), dtype=torch.float32, device=device)
+def create_attention_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"):
+    assert q_len <= kv_len
+
+    causal_mask = torch.full((q_len, q_len), fill_value=float("-inf"), device=device).triu(diagonal=1)
+
+    padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device)
     for i in range(bsz):
         cur_seq_len = kv_lengths[i].item()
-        assert cur_seq_len <= kv_seq_len
-        padding_mask[i, :, :, : kv_seq_len - cur_seq_len] = float("-inf")
+        assert cur_seq_len <= kv_len
+        padding_mask[i, :, :, : kv_len - cur_seq_len] = float("-inf")
+
+    padding_mask[:, :, -q_len:, -q_len:] += causal_mask
+
     return padding_mask
@ -33,12 +40,12 @@ def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, kv_seq_len: int, de
 # https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350
 def torch_attn_ref(
     q: torch.Tensor,  # [bsz, num_heads, q_len, head_dim]
-    k: torch.Tensor,  # [bsz, num_heads, kv_seq_len, head_dim]
-    v: torch.Tensor,  # [bsz, num_heads, kv_seq_len, head_dim]
-    attention_mask: torch.Tensor,  # [bsz, 1, seq_len, kv_seq_len]
+    k: torch.Tensor,  # [bsz, num_heads, kv_len, head_dim]
+    v: torch.Tensor,  # [bsz, num_heads, kv_len, head_dim]
+    attention_mask: torch.Tensor,  # [bsz, 1, q_len, kv_len]
     bsz: int,
-    seq_len: int,
-    kv_seq_len: int,
+    q_len: int,
+    kv_len: int,
     num_heads: int,
     num_kv_heads: int,
     head_dim: int,
@ -54,22 +61,24 @@ def torch_attn_ref(
 
     qk = torch.matmul(q, k.transpose(2, 3))
     attn_scores = qk / (head_dim**0.5)
-    assert attn_scores.shape == (bsz, num_heads, seq_len, kv_seq_len), "Invalid shape of attention scores"
-    # for left-side padding
-    if attention_mask.size() != (bsz, 1, seq_len, kv_seq_len):
-        raise ValueError(
-            f"Attention mask should be of size {(bsz, 1, seq_len, kv_seq_len)}, but is {attention_mask.size()}"
-        )
-
-    attn_scores = attn_scores + attention_mask
+    assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores"
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}"
+            )
+        attn_scores = attn_scores + attention_mask
 
     attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
     out = torch.matmul(attn_weights, v)
-    if out.size() != (bsz, num_heads, seq_len, head_dim):
+    if out.size() != (bsz, num_heads, q_len, head_dim):
         raise ValueError(
-            f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}"
+            f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}, but is" f" {out.size()}"
         )
     out = out.transpose(1, 2).contiguous()
-    out = out.squeeze(1)
+    out = out.view(-1, out.size(-2), out.size(-1))
+    # out [bsz * q_len, num_heads, head_dim]
     return out
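To make the new mask layout concrete, a small worked example (it assumes torch is imported and the create_attention_mask helper above is in scope): for one sequence of true length 3, with q_len=2 and kv_len=4, left-side padding masks the leading kv_len - cur_seq_len column, and the causal part is added only to the trailing q_len x q_len corner.

mask = create_attention_mask(torch.tensor([3]), bsz=1, q_len=2, kv_len=4, device="cpu")
# mask[0, 0] is (0 = attend, -inf = masked):
#   kv position:     0      1      2      3
#   query row 0:  [-inf,    0,     0,  -inf]   # one padded slot; causality hides position 3
#   query row 1:  [-inf,    0,     0,     0]   # the last query token sees every real position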
@ -6,8 +6,8 @@ from colossalai.kernel.triton import flash_decoding_attention
 from colossalai.utils import get_current_device
 from tests.test_infer.test_ops.triton.kernel_utils import (
     convert_kv_unpad_to_padded,
+    create_attention_mask,
     generate_caches_and_block_tables_v2,
-    prepare_padding_mask,
     torch_attn_ref,
 )
@ -21,7 +21,6 @@ except ImportError:
 TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
 
-Q_LEN = 1
 HEAD_DIM = 128
@ -64,6 +63,7 @@ def prepare_data(
 @pytest.mark.parametrize("num_attn_heads", [16])
 @pytest.mark.parametrize("kv_group_num", [1, 2, 16])
 @pytest.mark.parametrize("same_context_len", [True, False])
+@pytest.mark.parametrize("q_len", [1, 5])
 def test_flash_decoding(
     bsz: int,
     block_size: int,
@ -71,6 +71,7 @@ def test_flash_decoding(
     num_attn_heads: int,
     kv_group_num: int,
     same_context_len: bool,
+    q_len: int,
 ):
     torch.manual_seed(123)
     torch.cuda.empty_cache()
@ -82,47 +83,57 @@ def test_flash_decoding(
     max_seq_len = block_size * max_num_blocks_per_seq
     dtype = torch.float16
     device = get_current_device()
-    q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
-        bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
+    q, k_unpad, v_unpad, kv_lengths = prepare_data(
+        bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, q_len, max_seq_len, dtype, device
     )
+    # The maximum sequence length in the batch (if context lengths randomly generated)
+    max_kv_len_in_b = kv_lengths.max().item()
+
+    k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b)
+    v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b)
+    attention_mask = create_attention_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device)
+    out_torch = torch_attn_ref(
+        q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
+    )
+
     k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
-        k_unpad, v_unpad, kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
+        k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
     )
     block_tables = block_tables.to(device=device)
-    # The maximum sequence length in the batch (if context lengths randomly generated)
-    max_seq_len_in_b = kv_seq_lengths.max().item()
     # The maximum block length split on kv should be the kv cache block size
-    kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
-    output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
+    kv_max_split_num = (max_kv_len_in_b + block_size - 1) // block_size
+    output = torch.empty((bsz * q_len, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
     mid_output = torch.empty(
-        size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
+        size=(bsz * q_len, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
+    )
+    mid_output_lse = torch.empty(
+        size=(bsz * q_len, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device
     )
-    mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
     sm_scale = 1.0 / (HEAD_DIM**0.5)
+    # Here we use different methods to hide the q_len dimension,
+    # refer to attention forward function in modeling.
+    if q_len > 1:
+        q = q.transpose(1, 2).contiguous()  # [bsz, q_len, num_heads, head_dim]
+        q = q.view(-1, q.size(-2), q.size(-1))  # [bsz * q_len, num_heads, head_dim]
+    else:
+        q = q.squeeze(2)
+    assert q.shape == (bsz * q_len, num_attn_heads, HEAD_DIM)
+
     out_triton = flash_decoding_attention(
-        # Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1),
-        # refer to attention forward in modeling.
-        q.squeeze(2),
+        q,
         k_cache,
         v_cache,
-        kv_seq_lengths,
+        kv_lengths,
         block_tables,
         block_size,
-        max_seq_len_in_b,
+        max_kv_len_in_b,
         output,
         mid_output,
         mid_output_lse,
         sm_scale=sm_scale,
         kv_group_num=kv_group_num,
-    )  # [bsz, 1, num_heads, head_dim]
+        q_len=q_len,
+    )  # [bsz * q_len, num_heads, head_dim]
 
-    k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, bsz, max_seq_len_in_b)
-    v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, bsz, max_seq_len_in_b)
-    torch_padding_mask = prepare_padding_mask(kv_seq_lengths, bsz, max_seq_len_in_b, q.device)
-    out_torch = torch_attn_ref(
-        q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
-    )
-
     assert out_torch.shape == out_triton.shape
     assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
@ -2,7 +2,7 @@ import pytest
 import torch
 from packaging import version
 
-from colossalai.kernel.triton import copy_kv_to_blocked_cache
+from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
 from colossalai.utils import get_current_device
 from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
@ -16,7 +16,7 @@ except ImportError:
 TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
 
-HEAD_DIM = 128
+HEAD_DIM = 32
 
 
 def prepare_data(
@ -27,15 +27,16 @@ def prepare_data(
     max_num_blocks_per_seq,
     same_context_len,
     max_seq_len,
-    device,
+    n=1,
+    device="cuda",
     dtype=torch.float16,
 ):
-    # past_kv_seq_lengths in this test records the previous kv seq len
-    # (not incorporating the current input whose seq len is 1)
+    assert max_seq_len > n, "max_seq_len must be greater than n"
     past_kv_seq_lengths = (
-        torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
+        torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device)
         if same_context_len
-        else torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
+        else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device)
     )
     num_tokens = torch.sum(past_kv_seq_lengths).item()
@ -48,14 +49,14 @@ def prepare_data(
     )
     block_tables = block_tables.to(device=device)
 
-    new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
-    new_v = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
+    new_k = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
+    new_v = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
     # mock allocating blocks for the new k/v and update block tables
-    mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
-    # kv seq len = past kv seq len + seq len (1 during decoding stage)
-    kv_seq_lengths = past_kv_seq_lengths + 1
+    for _ in range(n):
+        mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
+        past_kv_seq_lengths += 1
 
-    return new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables
+    return new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables
 
 
 @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@ -64,12 +65,9 @@ def prepare_data(
 @pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
 @pytest.mark.parametrize("num_kv_heads", [16])
 @pytest.mark.parametrize("same_context_len", [True, False])
+@pytest.mark.parametrize("n_tokens", [1, 5])
 def test_copy_kv_to_caches(
-    bsz: int,
-    block_size: int,
-    max_num_blocks_per_seq: int,
-    num_kv_heads: int,
-    same_context_len: bool,
+    bsz: int, block_size: int, max_num_blocks_per_seq: int, num_kv_heads: int, same_context_len: bool, n_tokens: int
 ):
     torch.manual_seed(123)
     torch.cuda.empty_cache()
@ -88,25 +86,49 @@ def test_copy_kv_to_caches(
         max_num_blocks_per_seq,
         same_context_len,
         max_seq_len,
+        n_tokens,
         device=device,
         dtype=dtype,
     )
-    # k_cache_torch = k_cache.clone().detach()
-    # copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding")
-    copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
+    k_source = new_k.view(-1, new_k.size(-2), new_k.size(-1))
+    v_source = new_v.view(-1, new_v.size(-2), new_v.size(-1))
+    k_cache_copy = k_cache.detach().clone()
+    past_kv_seq_lengths = kv_seq_lengths - n_tokens
+    target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_lengths // block_size]
+    offsets_in_block = past_kv_seq_lengths % block_size
 
-    past_kv_seq_len = kv_seq_lengths - 1
-    target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
-    offsets_in_block = past_kv_seq_len % block_size
-    k_target = k_cache[target_block_ids, :, offsets_in_block, :]
-    k_source = new_k.squeeze()
-    v_target = v_cache[target_block_ids, :, offsets_in_block, :]
-    v_source = new_v.squeeze()
+    # Copy k (or v) to k (or v) cache
+    copy_k_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens)
+    # Reshape target k from k cache to compare if matching with original tensor
+    # Mainly to handle cases of n_tokens > 1
+    k_target = []
+    for i in range(bsz):
+        block_table = block_tables[i]
+        curr_kv_len = past_kv_seq_lengths[i].item()
+        offset = offsets_in_block[i].item()
+        tokens_left = n_tokens
+        while tokens_left > 0:
+            tokens_to_fill = min(block_size - offset, tokens_left)
+            curr_block_id = block_table[curr_kv_len // block_size]
+            k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :])
+            curr_kv_len += tokens_to_fill
+            tokens_left -= tokens_to_fill
+            offset = 0
+    k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous()  # [bsz * n, num_kv_heads, head_dim]
 
     assert k_target.shape == k_source.shape
     assert torch.equal(k_target, k_source)
-    assert v_target.shape == v_source.shape
-    assert torch.equal(v_target, v_source)
+
+    if n_tokens == 1:
+        # Copy k and v to k/v caches
+        k_cache = k_cache_copy
+        copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
+        k_target = k_cache_copy[target_block_ids, :, offsets_in_block, :]
+        v_target = v_cache[target_block_ids, :, offsets_in_block, :]
+        assert k_target.shape == k_source.shape
+        assert torch.equal(k_target, k_source)
+        assert v_target.shape == v_source.shape
+        assert torch.equal(v_target, v_source)
 
 
 if __name__ == "__main__":
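The gather loop in the updated test relies on the usual paged-KV addressing, where a token's logical block is position // block_size and its slot within that block is position % block_size. A small worked example with illustrative numbers:

# With block_size = 16, a sequence already holding 30 tokens receives n_tokens = 5
# new ones at positions 30..34, which straddle a block boundary.
block_size, past_len, n_tokens = 16, 30, 5
for pos in range(past_len, past_len + n_tokens):
    logical_block, offset = pos // block_size, pos % block_size
    # pos=30 -> block 1, offset 14; pos=31 -> block 1, offset 15; pos=32 -> block 2, offset 0; ...

# block_tables[seq_id][logical_block] then maps each logical block to a physical block
# in k_cache; the while-loop above walks exactly this mapping when it rebuilds k_target.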