[Inference] Adapt Baichuan2-13B TP (#5659)

* adapt to baichuan2 13B

* add baichuan2 13B TP

* update baichuan tp logic

* rm unused code

* Fix TP logic

* fix alibi slopes tp logic

* rm nn.Module

* Polished the code.

* change BAICHUAN_MODEL_NAME_OR_PATH

* Modified the logic for loading Baichuan weights.

* fix typos
yuehuayingxueluo 2024-04-30 15:47:07 +08:00 committed by GitHub
parent 808ee6e4ad
commit 5f00002e43
7 changed files with 280 additions and 98 deletions

@@ -26,7 +26,7 @@ _ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
_DEFAULT_PROMPT_TEMPLATES = {
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
"baichuan": "<reserved_106>{input_text}<reserved_107>",
"baichuan": " <reserved_106> {input_text} <reserved_107> ",
"vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
}
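For clarity, a minimal sketch of how the updated Baichuan template (now with spaces around the reserved tokens) is applied to a prompt; `apply_template` is an illustrative helper, not part of the ColossalAI API:

```python
# Minimal sketch: formatting a prompt with the Baichuan template from
# _DEFAULT_PROMPT_TEMPLATES above. `apply_template` is illustrative only.
BAICHUAN_TEMPLATE = " <reserved_106> {input_text} <reserved_107> "

def apply_template(input_text: str) -> str:
    return BAICHUAN_TEMPLATE.format(input_text=input_text)

print(apply_template("Introduce some landmarks in Paris"))
# ' <reserved_106> Introduce some landmarks in Paris <reserved_107> '
```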

@@ -112,11 +112,23 @@ class InferenceEngine:
model_policy (Policy): the policy to replace the model
"""
casuallm = None
if isinstance(model_or_path, str):
try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
arch = getattr(hf_config, "architectures")[0]
model = _supported_models[arch](hf_config)
if arch in _supported_models.keys():
casuallm = _supported_models[arch](hf_config)
if isinstance(casuallm, AutoModelForCausalLM):
# NOTE(caidi) It's necessary to add half() here, otherwise Baichuan2-13B will run out of GPU memory.
model = (
AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
)
else:
model = _supported_models[arch](hf_config)
else:
raise ValueError(f"Model {arch} is not supported.")
except Exception as e:
self.logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
@@ -164,7 +176,7 @@ class InferenceEngine:
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
if isinstance(model_or_path, str):
if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
from colossalai.inference.core.plugin import InferCheckpoint_io
cpt_io = InferCheckpoint_io()
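In short, the new branch loads AutoModelForCausalLM-backed architectures (such as Baichuan2-13B) with full pretrained weights in fp16 on GPU, and only uses the config-only path plus InferCheckpoint_io for the other supported models. A hedged paraphrase as standalone code: it checks the registry entry directly instead of instantiating from config and checking the instance type as the diff does, but the overall effect is the same.

```python
# Sketch of the new loading branch in InferenceEngine (assumes transformers and a CUDA device).
from transformers import AutoConfig, AutoModelForCausalLM

def load_model(model_or_path: str, supported_models: dict):
    hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
    arch = hf_config.architectures[0]
    if arch not in supported_models:
        raise ValueError(f"Model {arch} is not supported.")
    if supported_models[arch] is AutoModelForCausalLM:
        # half() keeps Baichuan2-13B from exhausting GPU memory during loading.
        return AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
    # Other supported models are built from config; their shards are loaded later via InferCheckpoint_io.
    return supported_models[arch](hf_config)
```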

@@ -0,0 +1,43 @@
from typing import List, Union
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.shardformer.layer import Linear1D_Col
from colossalai.shardformer.layer.parallel_module import ParallelModule
class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
module.in_features = module.weight.size(1)
module.out_features = module.weight.size(0)
module.bias = None
module.weight.data = nn.functional.normalize(module.weight)
return Linear1D_Col.from_native_module(
module,
process_group,
*args,
**kwargs,
)
class BaichuanWpackLinear1D_Col(Linear1D_Col):
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
in_features = module.in_features * 3
out_features = module.out_features // 3
module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)
module.bias = None
return Linear1D_Col.from_native_module(
module,
process_group,
*args,
**kwargs,
)
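The `W_pack` rearrangement in `BaichuanWpackLinear1D_Col.from_native_module` is easiest to verify on a tiny tensor; the sketch below reproduces the same view/transpose/reshape sequence with an illustrative hidden size and checks that row i of the result concatenates row i of q, k and v, which is what lets a column-parallel split keep matching q/k/v rows on one rank:

```python
# Tiny numeric check of the W_pack weight rearrangement (hidden size is illustrative).
import torch

hidden = 4
w_pack = torch.arange(3 * hidden * hidden, dtype=torch.float32).reshape(3 * hidden, hidden)
q, k, v = w_pack.chunk(3, dim=0)  # the packed weight stacks q, k, v along dim 0

out_features = (3 * hidden) // 3           # module.out_features // 3
in_features = hidden * 3                   # module.in_features * 3
rearranged = w_pack.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)

# Row i of the rearranged weight is [q[i], k[i], v[i]] concatenated, so a
# column-parallel shard over dim 0 keeps q, k and v for the same rows together.
for i in range(hidden):
    assert torch.equal(rearranged[i], torch.cat([q[i], k[i], v[i]]))
```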

@@ -1,11 +1,14 @@
# This code is adapted from the huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import itertools
import math
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
@@ -16,6 +19,18 @@ from colossalai.kernel.triton import (
rotary_embedding,
)
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
logger = get_dist_logger(__name__)
try:
from flash_attn import flash_attn_varlen_func
use_flash_attn2 = True
except ImportError:
use_flash_attn2 = False
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
logger = get_dist_logger(__name__)
@@ -78,14 +93,18 @@ def baichuan_rmsnorm_forward(
return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual)
class NopadBaichuanAttention(nn.Module):
class NopadBaichuanAttention(ParallelModule):
def __init__(
self,
config,
attn_qproj_w: torch.Tensor = None,
attn_kproj_w: torch.Tensor = None,
attn_vproj_w: torch.Tensor = None,
attn_oproj_w: torch.Tensor = None,
attn_oproj: ParallelModule = None,
num_heads: int = None,
hidden_size: int = None,
process_group: ProcessGroup = None,
helper_layout: Layout = None,
):
"""This layer will replace the BaichuanAttention.
@@ -94,26 +113,35 @@ class NopadBaichuanAttention(nn.Module):
attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj module. Defaults to None.
"""
super().__init__()
self.o_proj_weight = attn_oproj_w
ParallelModule.__init__(self)
self.o_proj = attn_oproj
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.num_heads = num_heads
self.hidden_size = hidden_size
self.head_dim = self.hidden_size // self.num_heads
self.process_group = process_group
qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
self.helper_layout = helper_layout
self.alibi_slopes = None
self.use_alibi_attn = False
if self.hidden_size == 5120:
# Used for Baichuan13B
if config.hidden_size == 5120:
slopes_start = self.process_group.rank() * num_heads
self.use_alibi_attn = True
self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)
qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
slopes_start : slopes_start + num_heads
].contiguous()
@staticmethod
def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBaichuanAttention":
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> "NopadBaichuanAttention":
"""Used for initialize the weight of NopadBaichuanAttention by origin BaichuanAttention.
Args:
@@ -121,24 +149,76 @@ class NopadBaichuanAttention(nn.Module):
"""
config = module.config
q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1)
q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((3, module.hidden_size, module.hidden_size))
attn_qproj_w = q_proj_w
attn_kproj_w = k_proj_w
attn_vproj_w = v_proj_w
attn_oproj = module.o_proj
attn_qproj_w = q_proj_w.transpose(0, 1)
attn_kproj_w = k_proj_w.transpose(0, 1)
attn_vproj_w = v_proj_w.transpose(0, 1)
attn_oproj_w = module.o_proj.weight.transpose(0, 1)
helper_layout = (
module.W_pack.weight.dist_layout
) # NOTE this is a hack to get the right load/shard layout for qkv_weight (used in _load_from_state_dict)
attn_layer = NopadBaichuanAttention(
config=config,
attn_qproj_w=attn_qproj_w,
attn_kproj_w=attn_kproj_w,
attn_vproj_w=attn_vproj_w,
attn_oproj_w=attn_oproj_w,
attn_oproj=attn_oproj,
num_heads=module.num_heads,
hidden_size=module.hidden_size,
process_group=process_group,
helper_layout=helper_layout,
)
return attn_layer
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
key = "qkv_weight"
qkv_w = state_dict[prefix + "W_pack.weight"]
in_features = qkv_w.size(1)
out_features = qkv_w.size(0) // 3
qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3)
device_mesh = self.helper_layout.device_mesh
sharding_spec = self.helper_layout.sharding_spec
qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec)
qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1)
input_param = nn.Parameter(
qkv_w
) # NOTE qkv_weight doesn't have to be a distributed tensor (no need for input_param = sharded_tensor_to_param(input_param))
param = local_state[key]
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)
strict = False # to avoid unexpected_keys
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(
self,
hidden_states: torch.Tensor,
@@ -292,56 +372,38 @@ class NopadBaichuanAttention(nn.Module):
)
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj_weight)
attn_output = self.o_proj(attn_output)
return attn_output
def extra_repr(self) -> str:
return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"
# NOTE This will cause a difference as the output length increases.
class NopadBaichuanMLP(nn.Module):
def __init__(
self,
mlp_gproj_w: torch.Tensor = None,
mlp_uproj_w: torch.Tensor = None,
mlp_dproj_w: torch.Tensor = None,
):
"""This layer will replace the BaichuanAttention.
Args:
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
"""
super().__init__()
self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
self.down_proj_weight = mlp_dproj_w
class NopadBaichuanMLP(NopadLlamaMLP):
@staticmethod
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
"""Used for initialize the weight of NopadBaichuanMLP by origin MLP(Baichuan).
Args:
module (nn.Module): The origin MLP(Baichuan) layer.
"""
mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
mlp_gproj_w = module.gate_proj.weight
assert is_distributed_tensor(
module.gate_proj.weight
), "gate_proj.weight must be dtensor so we could get the layout of the weight"
mlp_uproj_w = module.up_proj.weight
mlp_dproj = module.down_proj
mlp_layer = NopadBaichuanMLP(
config=None,
mlp_gproj_w=mlp_gproj_w,
mlp_uproj_w=mlp_uproj_w,
mlp_dproj_w=mlp_dproj_w,
mlp_dproj=mlp_dproj,
process_group=process_group,
)
return mlp_layer
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
"""
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
return torch.mm(act_out, self.down_proj_weight)
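A note on the ALiBi change earlier in this file: for the 13B model each tensor-parallel rank now keeps only its own contiguous block of slopes (`slopes_start = rank * num_heads`). A hedged sketch of that partitioning; `simple_alibi_slopes` is a simplified stand-in for the real `get_alibi_slopes` helper and assumes the total head count is a power of two:

```python
# Sketch: per-rank slicing of ALiBi slopes under tensor parallelism.
import torch

def simple_alibi_slopes(total_heads: int) -> torch.Tensor:
    # Standard ALiBi slopes for a power-of-two head count: 2^(-8/n), 2^(-16/n), ...
    ratio = 2.0 ** (-8.0 / total_heads)
    return torch.tensor([ratio ** (i + 1) for i in range(total_heads)])

total_heads, tp_size = 8, 2                    # illustrative sizes
heads_per_rank = total_heads // tp_size
full_slopes = simple_alibi_slopes(total_heads)

for rank in range(tp_size):
    start = rank * heads_per_rank              # mirrors slopes_start = rank * num_heads
    rank_slopes = full_slopes[start : start + heads_per_rank].contiguous()
    print(f"rank {rank}: {rank_slopes.tolist()}")
```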

@@ -1,6 +1,7 @@
import torch.nn as nn
from torch.nn import Parameter
from colossalai.inference.modeling.layers.baichuan_tp_linear import (
BaichuanLMHeadLinear1D_Col,
BaichuanWpackLinear1D_Col,
)
from colossalai.inference.modeling.models.nopadding_baichuan import (
NopadBaichuanAttention,
NopadBaichuanMLP,
@@ -12,6 +13,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
llama_model_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
@@ -23,39 +25,72 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
def module_policy(self):
policy = super().module_policy()
decoder_attribute_replacement = {
"lm_head.weight": Parameter(nn.functional.normalize(self.model.lm_head.weight), requires_grad=False),
}
policy["BaichuanForCausalLM"] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
if self.shard_config.enable_tensor_parallelism:
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
)
else:
decoder_attribute_replacement = None
# used for replacing Baichuan 7B/13B decoder layer
for layer_name in ["DecoderLayer", "BaichuanLayer"]:
policy[layer_name] = ModulePolicyDescription(
# used for both the Baichuan 7B and 13B decoder layers
for DecoderLayer in ["DecoderLayer", "BaichuanLayer"]:
policy[DecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp",
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
suffix="self_attn.W_pack",
target_module=BaichuanWpackLinear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadBaichuanAttention,
),
]
],
)
self.append_or_create_method_replacement(
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=layer_name
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=DecoderLayer
)
policy["BaichuanForCausalLM"] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=BaichuanLMHeadLinear1D_Col, kwargs={"gather_output": True}
)
],
)
self.append_or_create_method_replacement(
description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM"
)
self.append_or_create_method_replacement(
description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel"
)
self.append_or_create_method_replacement(
description={"forward": baichuan_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
)
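The policy change above replaces the old attribute-level `lm_head.weight` normalization with a `BaichuanLMHeadLinear1D_Col` module replacement, so the normalization happens once in `from_native_module` and the normalized weight is then sharded column-parallel over the vocabulary, with `gather_output=True` reassembling full logits. A small sketch of why that is equivalent (shapes are illustrative):

```python
# Sketch: normalizing lm_head.weight once and sharding it over the vocab dimension
# produces the same logits as the unsharded normalized head.
import torch
import torch.nn.functional as F

vocab, hidden, tp_size = 16, 8, 2
weight = torch.randn(vocab, hidden)
x = torch.randn(3, hidden)

normalized = F.normalize(weight)                  # unit L2 norm per vocab row, as in Baichuan2's NormHead
full_logits = x @ normalized.t()

shards = normalized.chunk(tp_size, dim=0)         # column-parallel split over the vocab dimension
gathered = torch.cat([x @ s.t() for s in shards], dim=-1)   # what gather_output=True reassembles
assert torch.allclose(full_logits, gathered)
```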

@@ -4,26 +4,29 @@ import random
import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch.multiprocessing import Manager
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
BAICHUAN_MODEL_NAME_OR_PATH = "/home/data/models/Baichuan2-13B-Base"
BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base"
def setup_seed(seed):
torch.manual_seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None):
def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, trust_remote_code=True).half().cuda()
@@ -34,7 +37,6 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
]
output_len = 38
do_sample = do_sample
if do_sample:
top_p = 0.5
@@ -45,9 +47,12 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
if use_engine:
inference_config = InferenceConfig(
max_output_len=output_len, prompt_template=prompt_template, use_cuda_kernel=use_cuda_kernel
max_output_len=output_len,
prompt_template=prompt_template,
use_cuda_kernel=use_cuda_kernel,
tp_size=dist.get_world_size(),
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
@@ -70,31 +75,54 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
return outputs
@parameterize("prompt_template", [None, "baichuan"])
@parameterize("do_sample", [True, False])
@parameterize("use_cuda_kernel", [True, False])
def check_output_consistency(prompt_template, do_sample, use_cuda_kernel):
cai_outputs = check_inference_engine(
use_engine=True, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
)
transformer_outputs = check_inference_engine(
use_engine=False, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
)
def run_engine(world_size, **kwargs):
manager = Manager()
result_list = manager.list([-1] * world_size) # Create a shared list
for s1, s2 in zip(cai_outputs, transformer_outputs):
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
# clear singleton flash decoding tensors
FDIntermTensors._instances = {}
spawn(run_dist, world_size, func_to_run=check_inference_engine, ret=result_list, **kwargs)
return result_list[0]
def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_output_consistency()
if ret:
ret[rank] = func_to_run(**kwargs)
else:
func_to_run(**kwargs)
# NOTE(caidi) If do_sample is set to True or use_cuda_kernel is set to False, the inference result will differ from that of transformers.
@parameterize("prompt_template", [None, "baichuan"])
@parameterize("do_sample", [False])
@parameterize("use_cuda_kernel", [True])
def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
kwargs1 = {
"use_engine": True,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": NoPaddingBaichuanModelInferPolicy(),
"use_cuda_kernel": use_cuda_kernel,
}
kwargs2 = {
"use_engine": False,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": None,
"use_cuda_kernel": use_cuda_kernel,
}
colossal_tp_1_output = run_engine(1, **kwargs1)
colossal_tp_2_output = run_engine(2, **kwargs1)
transformer_tp_1_output = run_engine(1, **kwargs2)
for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
@pytest.mark.skipif(
@@ -104,7 +132,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_inference_engine():
spawn(run_dist, 1)
test_tp_engine()
if __name__ == "__main__":

@@ -193,6 +193,7 @@ def test_vllm_flash_decoding_attention(
max_seq_len_across_batch = kv_seq_lengths.max().item()
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
kv_scale = 1.0
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
@@ -250,6 +251,7 @@ def test_vllm_flash_decoding_attention(
max_seq_len_across_batch,
alibi_slopes,
"auto",
kv_scale,
)
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)