[Inference] Adapt Baichuan2-13B TP (#5659)

* adapt to baichuan2 13B

* add baichuan2 13B TP

* update baichuan tp logic

* rm unused code

* Fix TP logic

* fix alibi slopes tp logic

* rm nn.Module

* Polished the code.

* change BAICHUAN_MODEL_NAME_OR_PATH

* Modified the logic for loading Baichuan weights.

* fix typos
This commit is contained in:
yuehuayingxueluo
2024-04-30 15:47:07 +08:00
committed by GitHub
parent 808ee6e4ad
commit 5f00002e43
7 changed files with 280 additions and 98 deletions

View File

@@ -4,26 +4,29 @@ import random
import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch.multiprocessing import Manager
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
BAICHUAN_MODEL_NAME_OR_PATH = "/home/data/models/Baichuan2-13B-Base"
BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base"
def setup_seed(seed):
torch.manual_seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None):
def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, trust_remote_code=True).half().cuda()
@@ -34,7 +37,6 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
]
output_len = 38
do_sample = do_sample
if do_sample:
top_p = 0.5
@@ -45,9 +47,12 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
if use_engine:
inference_config = InferenceConfig(
max_output_len=output_len, prompt_template=prompt_template, use_cuda_kernel=use_cuda_kernel
max_output_len=output_len,
prompt_template=prompt_template,
use_cuda_kernel=use_cuda_kernel,
tp_size=dist.get_world_size(),
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
@@ -70,31 +75,54 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
return outputs
@parameterize("prompt_template", [None, "baichuan"])
@parameterize("do_sample", [True, False])
@parameterize("use_cuda_kernel", [True, False])
def check_output_consistency(prompt_template, do_sample, use_cuda_kernel):
cai_outputs = check_inference_engine(
use_engine=True, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
)
transformer_outputs = check_inference_engine(
use_engine=False, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
)
def run_engine(world_size, **kwargs):
manager = Manager()
result_list = manager.list([-1] * world_size) # Create a shared list
for s1, s2 in zip(cai_outputs, transformer_outputs):
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
# clear singleton flash decoding tensors
FDIntermTensors._instances = {}
spawn(run_dist, world_size, func_to_run=check_inference_engine, ret=result_list, **kwargs)
return result_list[0]
def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_output_consistency()
if ret:
ret[rank] = func_to_run(**kwargs)
else:
func_to_run(**kwargs)
# NOTE(caidi) If do_sample is set to True or use_cuda_kernel is set to False, the inference result will be different from that of the transformer.
@parameterize("prompt_template", [None, "baichuan"])
@parameterize("do_sample", [False])
@parameterize("use_cuda_kernel", [True])
def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
kwargs1 = {
"use_engine": True,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": NoPaddingBaichuanModelInferPolicy(),
"use_cuda_kernel": use_cuda_kernel,
}
kwargs2 = {
"use_engine": False,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": None,
"use_cuda_kernel": use_cuda_kernel,
}
colossal_tp_1_output = run_engine(1, **kwargs1)
colossal_tp_2_output = run_engine(2, **kwargs1)
transformer_tp_1_output = run_engine(1, **kwargs2)
for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
@pytest.mark.skipif(
@@ -104,7 +132,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_inference_engine():
spawn(run_dist, 1)
test_tp_engine()
if __name__ == "__main__":

View File

@@ -193,6 +193,7 @@ def test_vllm_flash_decoding_attention(
max_seq_len_across_batch = kv_seq_lengths.max().item()
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
kv_scale = 1.0
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
@@ -250,6 +251,7 @@ def test_vllm_flash_decoding_attention(
max_seq_len_across_batch,
alibi_slopes,
"auto",
kv_scale,
)
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)