[infer] fix test bug (#4838)

* fix test bug

* delete useless code

* fix typo
Xu Kai, 2023-10-04 10:01:03 +08:00, committed by GitHub
parent 013a4bedf0, commit d1fcc0fa4d
6 changed files with 56 additions and 51 deletions

Changed file: ChatGLM2 modeling code (class ChatGLMModel)

@@ -873,7 +873,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.rotary_pos_emb = RotaryEmbedding(
             rotary_dim // 2,
-            original_impl=config.original_rope,
+            # original_impl=config.original_rope,  # config has no attribute original_rope
             device=device,
             dtype=config.torch_dtype,
         )
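The argument is commented out because the tests now build ChatGLM2 from a bare ChatGLMConfig that has no original_rope attribute, so RotaryEmbedding falls back to its default implementation. A more defensive spelling would read the attribute with a fallback instead of assuming it exists; the sketch below only illustrates the getattr pattern on a generic config object, and the default value of True is an assumption standing in for whatever RotaryEmbedding would otherwise use:

from transformers import PretrainedConfig

config = PretrainedConfig()  # a config object that has no `original_rope` attribute
original_impl = getattr(config, "original_rope", True)  # assumed fallback default
print(original_impl)  # prints True instead of raising AttributeError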

Changed file: LLaMA inference script (run_llama_test(args))

@@ -43,7 +43,6 @@ def run_llama_test(args):
     tokenizer.pad_token_id = tokenizer.unk_token_id
     model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
     model = model.half()
     model_config = model.config
     shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
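Style note on the ShardConfig line above (not part of the commit): `True if args.tp_size > 1 else False` is just the comparison itself. A one-line sketch, with tp_size standing in for args.tp_size:

tp_size = 2  # stand-in value for args.tp_size
assert (True if tp_size > 1 else False) == (tp_size > 1)
enable_tensor_parallelism = tp_size > 1  # shorter, equivalent spelling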

Changed file: BLOOM tensor-parallel inference test

@@ -1,13 +1,14 @@
 import pytest
 import torch
 from packaging import version
+from transformers import BloomForCausalLM
+from transformers.models.bloom.configuration_bloom import BloomConfig

 import colossalai
 from colossalai.inference.tensor_parallel import TPInferEngine
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo

 TP_SIZE = 2
 MAX_BATCH_SIZE = 4
@@ -26,21 +27,23 @@ CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
     ],
 )
 def run(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_bloom_for_causal_lm")
-    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
-        orig_model = model_fn()
-        orig_model = orig_model.half()
-        data = data_gen_fn()
+    bloom_config = BloomConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
+    model = BloomForCausalLM(bloom_config)
+    model = model.half()

-        shard_config = ShardConfig(
-            enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
-        )
-        infer_engine = TPInferEngine(orig_model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
+    )
+    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)

-        generate_kwargs = dict(do_sample=False)
-        outputs = infer_engine.generate(data, **generate_kwargs)
+    input_tokens = {
+        "input_ids": torch.randint(1, 1000, (MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+        "attention_mask": torch.ones((MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+    }
+    outputs = infer_engine.generate(input_tokens, **generate_kwargs)

-        assert outputs is not None
+    assert outputs is not None


 def check_bloom(rank, world_size, port):
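The rewritten test no longer iterates over the model zoo; it instantiates a tiny randomly initialised BLOOM model from an explicit config and feeds it a random token batch. Below is a minimal standalone sketch of the same smoke-test pattern, using plain transformers generation instead of TPInferEngine so it runs on CPU; the sizes are arbitrary and smaller than the test's:

import torch
from transformers import BloomForCausalLM
from transformers.models.bloom.configuration_bloom import BloomConfig

# Tiny random-weight model: 2 layers, small hidden size, vocab large enough for the random ids below.
config = BloomConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=64)
model = BloomForCausalLM(config).eval()

input_tokens = {
    "input_ids": torch.randint(1, 1000, (4, 16)),            # (batch_size, input_len)
    "attention_mask": torch.ones((4, 16), dtype=torch.long),
}
with torch.no_grad():
    outputs = model.generate(**input_tokens, max_new_tokens=8, do_sample=False)
assert outputs is not None  # same smoke check as the test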

Changed file: ChatGLM2 tensor-parallel inference test

@@ -2,17 +2,15 @@ import os
 import pytest
 import torch
-import torch.distributed as dist
 from packaging import version
-from transformers import AutoTokenizer

 import colossalai
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo.transformers.chatglm2 import infer_config

 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 TPSIZE = 1
@@ -31,28 +29,31 @@ CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
     ],
 )
 def run_chatglm2_test(test_config):
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
-    # pad_token_id = 0
-    model_fn = lambda: ChatGLMForConditionalGeneration(infer_config, empty_init=False)
-    orig_model = model_fn()
-    orig_model = orig_model.half()
-    text = ["how is the weather today?"]
-    input_ids = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True)
+    chatglm_config = ChatGLMConfig(
+        num_layers=2,
+        vocab_size=1200,
+        use_cache=True,
+        multi_query_attention=True,
+        multi_query_group_num=2,
+        num_attention_heads=8,
+        hidden_size=1024,
+    )
+    model = ChatGLMForConditionalGeneration(chatglm_config)
+    model = model.half()

     shard_config = ShardConfig(
         enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
     )
-    infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
-    outputs = infer_engine.generate(input_ids, **generate_kwargs)
-    assert outputs is not None

-    # print("outputs.shape: ", outputs[0].shape)
-    # print("outputs: ", outputs[0])
-    if not dist.is_initialized() or dist.get_rank() == 0:
-        for o in outputs:
-            output_text = tokenizer.decode(o)
-            print(output_text)
+    input_tokens = {
+        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+    }
+    outputs = infer_engine.generate(input_tokens, **generate_kwargs)

+    assert outputs is not None


 def check_chatglm2(rank, world_size, port):
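The ChatGLM2 test follows the same pattern: no pretrained tokenizer or real prompt, just a small randomly initialised model plus random token ids. The sketch below is a hypothetical helper (not part of the commit) for building the batch the rewritten tests pass to TPInferEngine.generate(); it assumes a CUDA device, as the tests do, and keeps the ids below the configs' vocab_size of 1200:

import torch

def make_random_batch(batch_size: int, input_len: int, device: str = "cuda") -> dict:
    # Random token ids plus an all-ones attention mask, both shaped (batch_size, input_len).
    return {
        "input_ids": torch.randint(1, 1000, (batch_size, input_len), device=device),
        "attention_mask": torch.ones((batch_size, input_len), device=device),
    }

# Example with arbitrary sizes; the tests use their own BATCH_SIZE / MAX_INPUT_LEN constants:
# input_tokens = make_random_batch(4, 16)
# outputs = infer_engine.generate(input_tokens, **generate_kwargs)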

Changed file: LLaMA tensor-parallel inference test

@@ -3,13 +3,14 @@ import os
 import pytest
 import torch
 from packaging import version
+from transformers import LlamaForCausalLM
+from transformers.models.llama.configuration_llama import LlamaConfig

 import colossalai
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo

 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 TPSIZE = 2
@@ -29,21 +30,24 @@ CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
     ],
 )
 def run_llama_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm")
-    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
-        orig_model = model_fn()
-        orig_model = orig_model.half()
-        data = data_gen_fn()
+    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
+    model = LlamaForCausalLM(llama_config)
+    model = model.half()

-        shard_config = ShardConfig(
-            enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
-        )
-        infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
+    )
+    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    init_to_get_rotary(model.model, base=10000)
+    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)

-        generate_kwargs = dict(do_sample=False)
-        outputs = infer_engine.generate(data, **generate_kwargs)
+    input_tokens = {
+        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+    }
+    outputs = infer_engine.generate(input_tokens, **generate_kwargs)

-        assert outputs is not None
+    assert outputs is not None


 def check_llama(rank, world_size, port):
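Besides switching to a small random LlamaConfig model, the rewritten test calls init_to_get_rotary(model.model, base=10000) so the rotary-embedding caches are prepared before generation; 10000 is the standard RoPE base. The sketch below shows the usual cache computation such helpers perform, as an illustration of the idea rather than ColossalAI's exact implementation:

import torch

def build_rope_cache(head_dim: int, max_seq_len: int, base: float = 10000.0):
    # Inverse frequency per even dimension, then an outer product with the position index.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len).float()
    freqs = torch.outer(positions, inv_freq)  # (max_seq_len, head_dim // 2)
    return freqs.cos(), freqs.sin()

cos_cache, sin_cache = build_rope_cache(head_dim=128, max_seq_len=2048)
print(cos_cache.shape, sin_cache.shape)  # torch.Size([2048, 64]) for both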

Changed file: Triton token-attention kernel test

@@ -38,9 +38,7 @@ def test():
     q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
     k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
     v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
-    o = torch.empty_like()
-    # o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
+    o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda")
     max_kv_cache_len = seq_len
     kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
     kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
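The test bug here was calling torch.empty_like() with no tensor argument, which raises a TypeError before the kernel ever runs; the fix allocates the output buffer with an explicit shape instead. Either form below works, shown on CPU in float32 to stay self-contained (the test uses fp16 tensors on CUDA):

import torch

Z, head_num, head_dim = 4, 8, 64
q = torch.empty((Z, head_num, head_dim), dtype=torch.float32)

o_from_like = torch.empty_like(q)  # inherits q's shape, dtype and device
o_explicit = torch.empty((Z, head_num, head_dim), dtype=torch.float32)  # the form the fix uses
assert o_from_like.shape == o_explicit.shape == q.shape
# torch.empty_like() with no argument raises a TypeError; that was the test bug.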