mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-05 06:58:09 +00:00
[Inference] Fix flash-attn import and add model test (#5794)
* Fix torch int32 dtype Signed-off-by: char-1ee <xingjianli59@gmail.com> * Fix flash-attn import Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add generalized model test Signed-off-by: char-1ee <xingjianli59@gmail.com> * Remove exposed path to model Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add default value for use_flash_attn Signed-off-by: char-1ee <xingjianli59@gmail.com> * Rename model test Signed-off-by: char-1ee <xingjianli59@gmail.com> --------- Signed-off-by: char-1ee <xingjianli59@gmail.com>
This commit is contained in:
parent
aac941ef78
commit
8554585a5f
@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
|
|
||||||
from colossalai.inference.config import ModelShardInferenceConfig
|
from colossalai.inference.config import ModelShardInferenceConfig
|
||||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
@ -44,7 +43,7 @@ class CudaAttentionBackend(AttentionBackend):
|
|||||||
it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
|
it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, use_flash_attn: bool):
|
def __init__(self, use_flash_attn: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inference_ops = InferenceOpsLoader().load()
|
self.inference_ops = InferenceOpsLoader().load()
|
||||||
self.use_flash_attn = use_flash_attn
|
self.use_flash_attn = use_flash_attn
|
||||||
@ -52,6 +51,9 @@ class CudaAttentionBackend(AttentionBackend):
|
|||||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||||
if self.use_flash_attn:
|
if self.use_flash_attn:
|
||||||
token_nums = kwargs.get("token_nums", -1)
|
token_nums = kwargs.get("token_nums", -1)
|
||||||
|
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
attn_output = flash_attn_varlen_func(
|
attn_output = flash_attn_varlen_func(
|
||||||
attn_metadata.query_states,
|
attn_metadata.query_states,
|
||||||
attn_metadata.key_states,
|
attn_metadata.key_states,
|
||||||
|
@ -200,8 +200,6 @@ class NopadBaichuanAttention(ParallelModule):
|
|||||||
|
|
||||||
self.pre_attention_backend.decode(
|
self.pre_attention_backend.decode(
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
cos=cos_sin[0],
|
|
||||||
sin=cos_sin[1],
|
|
||||||
q_len=q_len,
|
q_len=q_len,
|
||||||
)
|
)
|
||||||
attn_output = self.attention_backend.decode(
|
attn_output = self.attention_backend.decode(
|
||||||
|
@ -114,7 +114,7 @@ def llama_model_forward(
|
|||||||
|
|
||||||
elif use_cuda_kernel:
|
elif use_cuda_kernel:
|
||||||
if can_use_flash_attn2(inputmetadata.dtype):
|
if can_use_flash_attn2(inputmetadata.dtype):
|
||||||
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
|
||||||
|
|
||||||
hidden_dim = self._cos_cached.size(-1)
|
hidden_dim = self._cos_cached.size(-1)
|
||||||
total_length = hidden_states.size(0)
|
total_length = hidden_states.size(0)
|
||||||
@ -265,7 +265,7 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
|
|||||||
mlp_dproj: ParallelModule = None,
|
mlp_dproj: ParallelModule = None,
|
||||||
process_group: ProcessGroup = None,
|
process_group: ProcessGroup = None,
|
||||||
):
|
):
|
||||||
"""A Unified Layer for
|
"""Replacement of LlamaMLP layer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (LlamaConfig): Holding the Llama model config.
|
config (LlamaConfig): Holding the Llama model config.
|
||||||
|
@ -152,6 +152,8 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from flash_attn import flash_attn_varlen_func # noqa
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
|
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
|
||||||
|
@ -50,7 +50,7 @@ def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.T
|
|||||||
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
|
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
|
||||||
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
|
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
|
||||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||||
return max_seqlen_in_batch, cu_seqlens, indices
|
return max_seqlen_in_batch, cu_seqlens, indices
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ def prepare_data(
|
|||||||
num_tokens = torch.sum(context_lengths).item()
|
num_tokens = torch.sum(context_lengths).item()
|
||||||
|
|
||||||
max_seq_len_in_batch = context_lengths.max()
|
max_seq_len_in_batch = context_lengths.max()
|
||||||
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.int32), (1, 0))
|
||||||
|
|
||||||
kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
|
kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
|
||||||
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||||
|
161
tests/test_infer/test_models/test_custom_model.py
Normal file
161
tests/test_infer/test_models/test_custom_model.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.multiprocessing import Manager
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaForCausalLM, LlamaTokenizer
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
import colossalai.inference.modeling.policy as policy
|
||||||
|
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
|
||||||
|
from colossalai.inference.core.engine import InferenceEngine
|
||||||
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
# NOTE: To test a model with the inference engine, you need to provide the path to your
|
||||||
|
# local pretrained model weights in the MODEL_MAP dictionary
|
||||||
|
MODEL_MAP = {
|
||||||
|
"baichuan": {
|
||||||
|
"model": AutoModelForCausalLM,
|
||||||
|
"tokenizer": AutoTokenizer,
|
||||||
|
"policy": policy.NoPaddingBaichuanModelInferPolicy,
|
||||||
|
"model_name_or_path": "baichuan-inc/Baichuan2-13B-Base", # provide the path to local model weights
|
||||||
|
},
|
||||||
|
"llama": {
|
||||||
|
"model": LlamaForCausalLM,
|
||||||
|
"tokenizer": LlamaTokenizer,
|
||||||
|
"policy": policy.NoPaddingLlamaModelInferPolicy,
|
||||||
|
"model_name_or_path": "meta-llama/Llama-2-70b-hf",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
MODELS_TO_TEST = ["llama", "baichuan"] # Specify the models to test
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize("model", MODELS_TO_TEST)
|
||||||
|
@parameterize("prompt_template", [None, "model_specific"])
|
||||||
|
@parameterize("do_sample", [False])
|
||||||
|
@parameterize("use_cuda_kernel", [True])
|
||||||
|
@pytest.mark.largedist
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_model(model, prompt_template, do_sample, use_cuda_kernel):
|
||||||
|
model_path = MODEL_MAP[model]["model_name_or_path"]
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
pytest.skip(
|
||||||
|
f"There is no local model address included for {model}, please replace this address with a valid one."
|
||||||
|
)
|
||||||
|
|
||||||
|
if prompt_template == "model_specific":
|
||||||
|
prompt_template = model
|
||||||
|
|
||||||
|
model_config = MODEL_MAP[model]
|
||||||
|
|
||||||
|
kwargs1 = {
|
||||||
|
"model": model,
|
||||||
|
"use_engine": True,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
"do_sample": do_sample,
|
||||||
|
"policy": model_config["policy"](),
|
||||||
|
"use_cuda_kernel": use_cuda_kernel,
|
||||||
|
}
|
||||||
|
|
||||||
|
kwargs2 = {
|
||||||
|
"model": model,
|
||||||
|
"use_engine": False,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
"do_sample": do_sample,
|
||||||
|
"policy": None,
|
||||||
|
"use_cuda_kernel": use_cuda_kernel,
|
||||||
|
}
|
||||||
|
|
||||||
|
colossal_tp_1_output = run_engine(1, **kwargs1)
|
||||||
|
colossal_tp_2_output = run_engine(2, **kwargs1)
|
||||||
|
transformer_tp_1_output = run_engine(1, **kwargs2)
|
||||||
|
|
||||||
|
for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
|
||||||
|
assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
|
||||||
|
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
|
||||||
|
|
||||||
|
|
||||||
|
def run_engine(world_size, **kwargs):
|
||||||
|
manager = Manager()
|
||||||
|
result_list = manager.list([-1] * world_size) # Create a shared list
|
||||||
|
spawn(run_dist, world_size, func_to_run=_run_engine, ret=result_list, **kwargs)
|
||||||
|
return result_list[0]
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
|
||||||
|
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||||
|
|
||||||
|
if ret:
|
||||||
|
ret[rank] = func_to_run(**kwargs)
|
||||||
|
else:
|
||||||
|
func_to_run(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_engine(model, use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
|
||||||
|
setup_seed(20)
|
||||||
|
model_config = MODEL_MAP[model]
|
||||||
|
model_name_or_path = model_config["model_name_or_path"]
|
||||||
|
tokenizer = model_config["tokenizer"].from_pretrained(model_name_or_path, use_fast=False, trust_remote_code=True)
|
||||||
|
model = model_config["model"].from_pretrained(model_name_or_path, trust_remote_code=True).half().cuda()
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
"Introduce some landmarks in Paris:",
|
||||||
|
]
|
||||||
|
|
||||||
|
output_len = 38
|
||||||
|
|
||||||
|
if do_sample:
|
||||||
|
top_p = 0.5
|
||||||
|
top_k = 50
|
||||||
|
else:
|
||||||
|
top_p = None
|
||||||
|
top_k = None
|
||||||
|
|
||||||
|
if use_engine:
|
||||||
|
inference_config = InferenceConfig(
|
||||||
|
max_output_len=output_len,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
use_cuda_kernel=use_cuda_kernel,
|
||||||
|
tp_size=dist.get_world_size(),
|
||||||
|
)
|
||||||
|
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
|
||||||
|
assert inference_engine.generation_config.max_new_tokens == output_len
|
||||||
|
inference_engine.add_request(prompts=inputs)
|
||||||
|
assert inference_engine.request_handler._has_waiting()
|
||||||
|
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
|
||||||
|
outputs = inference_engine.generate(generation_config=generation_config)
|
||||||
|
else:
|
||||||
|
if prompt_template:
|
||||||
|
# apply prompt template
|
||||||
|
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
|
||||||
|
inputs = inputs.cuda()
|
||||||
|
generation_config = GenerationConfig(
|
||||||
|
do_sample=do_sample,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
max_new_tokens=output_len,
|
||||||
|
)
|
||||||
|
outputs = model.generate(inputs, generation_config=generation_config)
|
||||||
|
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def setup_seed(seed):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_model()
|
Loading…
Reference in New Issue
Block a user