Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference] Adapt to Baichuan2 13B (#5614)
* Adapt to Baichuan2 13B
* Change BAICHUAN_MODEL_NAME_OR_PATH
* Fix test_decoding_attn.py
* Apply modifications based on review comments
* Move the attention mask processing into the flash-decoding test
* Move get_alibi_slopes into the Baichuan modeling code (see the sketch below)
* Fix bugs in test_baichuan.py
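Background for the get_alibi_slopes item: Baichuan2-13B uses ALiBi position bias rather than rotary embeddings, so each attention head needs a fixed per-head slope. The sketch below assumes the standard ALiBi recipe (slope_i = (2 ** (-8 / n)) ** i for n heads, plus an interleaved extension when n is not a power of two); the helper actually moved in this PR may differ in signature and device handling.

import math

import torch


def get_alibi_slopes(num_heads: int, device: torch.device = None) -> torch.Tensor:
    # Standard ALiBi slopes: for a power-of-two head count n, head i (1-indexed)
    # gets slope (2 ** (-8 / n)) ** i, a geometric sequence.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(2 ** (-8 / closest_power_of_2), dtype=torch.float32, device=device)
    slopes = torch.pow(base, torch.arange(1, closest_power_of_2 + 1, device=device))
    if closest_power_of_2 != num_heads:
        # Non-power-of-two head counts (Baichuan2-13B has 40 heads) take the remaining
        # slopes from the sequence for 2 * closest_power_of_2, using every other entry
        # starting from the first.
        extra_base = torch.tensor(2 ** (-4 / closest_power_of_2), dtype=torch.float32, device=device)
        num_extra = num_heads - closest_power_of_2
        extra = torch.pow(extra_base, torch.arange(1, 2 * num_extra + 1, 2, device=device))
        slopes = torch.cat([slopes, extra])
    return slopes

For Baichuan2-13B, get_alibi_slopes(40) yields one slope per attention head; ALiBi then biases each attention score by the head's slope times the (negative) query-key distance.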
@@ -12,7 +12,8 @@ from colossalai.inference.core.engine import InferenceEngine
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
-BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
+# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
+BAICHUAN_MODEL_NAME_OR_PATH = "/home/data/models/Baichuan2-13B-Base"
 
 
 def setup_seed(seed):
@@ -22,12 +23,10 @@ def setup_seed(seed):
     random.seed(seed)
 
 
-def check_inference_engine(use_engine=False, prompt_template=None):
+def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True)
-    model = AutoModelForCausalLM.from_pretrained(
-        BAICHUAN_MODEL_NAME_OR_PATH, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
-    ).cuda()
+    model = AutoModelForCausalLM.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, trust_remote_code=True).half().cuda()
     model = model.eval()
 
     inputs = [
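Note the baseline load above: the Hugging Face reference model now comes up in fp16 on a single GPU instead of bfloat16 sharded via device_map="auto", and the engine config in the next hunk drops its explicit dtype="fp32". Presumably this keeps the CUDA-kernel path and the reference model in the same half precision so their greedy outputs can be compared token for token. A standalone restatement of the new load, reusing the constant from the diff:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Local checkpoint path, exactly as hard-coded in the updated test.
BAICHUAN_MODEL_NAME_OR_PATH = "/home/data/models/Baichuan2-13B-Base"

tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True)
# fp16 weights on one device; the previous revision used torch_dtype=torch.bfloat16 with device_map="auto".
model = AutoModelForCausalLM.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, trust_remote_code=True).half().cuda().eval()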
@@ -35,17 +34,24 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     ]
 
     output_len = 38
-    do_sample = False
+    do_sample = do_sample
+
+    if do_sample:
+        top_p = 0.5
+        top_k = 50
+    else:
+        top_p = None
+        top_k = None
 
     if use_engine:
         inference_config = InferenceConfig(
-            max_output_len=output_len, prompt_template=prompt_template, dtype="fp32", use_cuda_kernel=True
+            max_output_len=output_len, prompt_template=prompt_template, use_cuda_kernel=use_cuda_kernel
         )
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample)
+        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
         if prompt_template:
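The sampling branch introduced above keeps the two backends comparable: greedy decoding (do_sample=False) leaves top_p/top_k unset, while sampling mode pins top_p=0.5 and top_k=50 on both the engine's GenerationConfig and, in the next hunk, the Hugging Face baseline's, so that together with the fixed seed from setup_seed(20) both paths decode under identical settings. A small helper restating that selection (values taken from the diff; the helper name is illustrative, not part of the test):

from transformers import GenerationConfig


def build_generation_config(do_sample: bool, output_len: int, pad_token_id=None) -> GenerationConfig:
    # Sampling gets fixed top_p/top_k so runs stay reproducible under a fixed seed;
    # greedy decoding leaves them unset, since they only apply when sampling.
    top_p = 0.5 if do_sample else None
    top_k = 50 if do_sample else None
    return GenerationConfig(
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=pad_token_id,
        max_new_tokens=output_len,
    )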
@@ -57,6 +63,8 @@ def check_inference_engine(use_engine=False, prompt_template=None):
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
             do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
             pad_token_id=tokenizer.pad_token_id,
             max_new_tokens=output_len,
         )
@@ -67,9 +75,15 @@ def check_inference_engine(use_engine=False, prompt_template=None):
 
 
 @parameterize("prompt_template", [None, "baichuan"])
-def check_output_consistency(prompt_template):
-    cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
-    transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
+@parameterize("do_sample", [True, False])
+@parameterize("use_cuda_kernel", [True, False])
+def check_output_consistency(prompt_template, do_sample, use_cuda_kernel):
+    cai_outputs = check_inference_engine(
+        use_engine=True, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
+    )
+    transformer_outputs = check_inference_engine(
+        use_engine=False, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
+    )
 
     for s1, s2 in zip(cai_outputs, transformer_outputs):
         assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
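Stacking three @parameterize decorators expands check_output_consistency into the full cross product of its arguments: 2 prompt templates x 2 sampling modes x 2 kernel settings = 8 runs, each comparing the ColossalAI engine against the Hugging Face baseline. A simplified stand-in for colossalai.testing.parameterize, under the assumption that it simply re-invokes the wrapped function once per listed value:

from functools import wraps


def parameterize(arg_name, values):
    # Simplified sketch, not the real colossalai.testing implementation:
    # call the wrapped function once per value, passed as a keyword argument.
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator


@parameterize("prompt_template", [None, "baichuan"])
@parameterize("do_sample", [True, False])
@parameterize("use_cuda_kernel", [True, False])
def show_case(prompt_template, do_sample, use_cuda_kernel):
    print(prompt_template, do_sample, use_cuda_kernel)


show_case()  # prints all 2 * 2 * 2 = 8 combinations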