Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 02:51:59 +00:00)
[Inference] Fix bug in ChatGLM2 Tensor Parallelism (#5014)
* fix bug
* fix
* fix multiquery
* fix multiquery

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
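The bug concerns ChatGLM2's multi-query attention under tensor parallelism. As a minimal sketch (not the actual ColossalAI fix), the issue class looks like this: the model has many query heads but only a few key/value groups, so a per-rank split that works for query heads may not divide the KV groups evenly, and a rule such as KV replication is one possible fallback. The function split_heads_for_tp and its parameters are hypothetical names for illustration only.

def split_heads_for_tp(num_q_heads: int, num_kv_groups: int, tp_size: int):
    """Per-rank head counts for one tensor-parallel attention shard (illustrative only)."""
    assert num_q_heads % tp_size == 0, "query heads must split evenly across ranks"
    q_per_rank = num_q_heads // tp_size
    if num_kv_groups % tp_size == 0:
        # enough KV groups: shard them the same way as the query heads
        kv_per_rank = num_kv_groups // tp_size
    else:
        # too few KV groups to split: replicate them on every rank
        kv_per_rank = num_kv_groups
    return q_per_rank, kv_per_rank

# ChatGLM2-6B-like shapes: 32 query heads, 2 KV groups, tp_size = 2
print(split_heads_for_tp(32, 2, 2))  # -> (16, 1)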
@@ -13,13 +13,14 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 try:
-    import lightllm
+    import lightllm  # noqa
+
     HAS_LIGHTLLM_KERNEL = True
 except:
     HAS_LIGHTLLM_KERNEL = False
 
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-TPSIZE = 1
+TPSIZE = 2
 BATCH_SIZE = 8
 MAX_INPUT_LEN = 12
 MAX_OUTPUT_LEN = 100
@@ -67,7 +68,10 @@ def check_chatglm2(rank, world_size, port):
     run_chatglm2_test()
 
 
-@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.skipif(
+    not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
+    reason="kv-cache manager engine requires cuda version to be higher than 11.5",
+)
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
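For reference, a minimal sketch of the test entrypoint that these decorators guard; its body lies outside the hunk shown above, so this assumes the usual ColossalAI test pattern in which colossalai.testing.spawn launches one worker process per rank and invokes check_chatglm2(rank, world_size, port) on each:

import pytest

from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn


@pytest.mark.skipif(
    not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
    reason="kv-cache manager engine requires cuda version to be higher than 11.5",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_chatglm2():
    # with TPSIZE raised from 1 to 2, this spawns two ranks and actually
    # exercises the tensor-parallel sharding path
    spawn(check_chatglm2, TPSIZE)

Raising TPSIZE from 1 to 2 is what makes the test shard the model across ranks at all, which is presumably why it accompanies the tensor-parallelism fix.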
|
||||
|