[Inference] Fix bug in ChatGLM2 Tensor Parallelism (#5014)

* fix bug

* fix

* fix multiquery

* fix multiquery

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Author: Jianghai
Date: 2023-11-07 15:01:50 +08:00
Committed by: GitHub
Parent: c36e782d80
Commit: ef4c14a5e2
8 changed files with 21 additions and 19 deletions


@@ -13,13 +13,14 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 try:
-    import lightllm
+    import lightllm # noqa
     HAS_LIGHTLLM_KERNEL = True
 except:
     HAS_LIGHTLLM_KERNEL = False
 
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-TPSIZE = 1
+TPSIZE = 2
 BATCH_SIZE = 8
 MAX_INPUT_LEN = 12
 MAX_OUTPUT_LEN = 100
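
Why TPSIZE moves from 1 to 2: with a single tensor-parallel rank nothing is actually sharded, so the multi-query bug this commit fixes would never surface in the test. As a rough illustration of the splitting the test now exercises (this is not code from the commit; the head counts below mirror ChatGLM2-6B defaults only as an assumption):

# Illustrative sketch, not the repository's implementation: ChatGLM2 uses
# multi-query attention, so query heads and key/value groups have to be
# partitioned separately across tensor-parallel ranks.
num_attention_heads = 32      # query heads (assumed ChatGLM2-6B default)
multi_query_group_num = 2     # key/value groups (assumed ChatGLM2-6B default)
tp_size = 2                   # matches TPSIZE = 2 in the test above

assert num_attention_heads % tp_size == 0 and multi_query_group_num % tp_size == 0
q_heads_per_rank = num_attention_heads // tp_size       # 16 query heads per rank
kv_groups_per_rank = multi_query_group_num // tp_size   # 1 key/value group per rank
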
@@ -67,7 +68,10 @@ def check_chatglm2(rank, world_size, port):
     run_chatglm2_test()
 
 
-@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.skipif(
+    not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
+    reason="kv-cache manager engine requires cuda version to be higher than 11.5",
+)
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
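
For context, these decorators normally sit on a test entry point that launches one worker process per tensor-parallel rank via spawn. The body of that function is outside this diff, so the sketch below is an assumption about the usual pattern, not part of the commit:

# Assumed test-launch pattern; relies on the spawn, TPSIZE, and check_chatglm2
# names already imported or defined in this test module.
def test_chatglm2():
    spawn(check_chatglm2, TPSIZE)  # one process per tensor-parallel rank


if __name__ == "__main__":
    test_chatglm2()
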