mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[hotfix] Suport extra_kwargs in ShardConfig (#5031)
* [refactor]: replace inference args with extra_kwargs in ShardConfig * modify shardconfig * polish code * fix policy bug in llama * fix bug in auto policy * remove setattr in ShardConfig
This commit is contained in:
@@ -19,7 +19,7 @@ def build_model(
|
||||
enable_tensor_parallelism=enable_tensor_parallelism,
|
||||
enable_flash_attention=enable_flash_attention,
|
||||
enable_jit_fused=enable_jit_fused,
|
||||
inference_only=True,
|
||||
extra_kwargs={"inference_only": True},
|
||||
)
|
||||
model_copy = copy.deepcopy(org_model)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
|
@@ -11,11 +11,10 @@ from colossalai.shardformer import ShardConfig
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
try:
|
||||
import lightllm
|
||||
HAS_LIGHTLLM_KERNEL = True
|
||||
except:
|
||||
HAS_LIGHTLLM_KERNEL = False
|
||||
|
||||
|
||||
TP_SIZE = 2
|
||||
MAX_BATCH_SIZE = 4
|
||||
MAX_INPUT_LEN = 16
|
||||
@@ -38,7 +37,7 @@ def run(test_config):
|
||||
model = model.half()
|
||||
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
|
||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
|
||||
)
|
||||
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
||||
@@ -58,7 +57,10 @@ def check_bloom(rank, world_size, port):
|
||||
run()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
|
||||
@pytest.mark.skipif(
|
||||
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
|
||||
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
|
||||
)
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
@@ -49,7 +49,7 @@ def run_chatglm2_test(test_config):
|
||||
model = model.half()
|
||||
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
|
||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
|
||||
)
|
||||
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
||||
|
@@ -34,7 +34,7 @@ def run():
|
||||
model = LlamaForCausalLM(llama_config)
|
||||
model = model.half()
|
||||
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=False, extra_kwargs={"inference_only": True})
|
||||
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||
|
||||
dynamic_batch_manager = DynamicBatchManager(
|
||||
|
@@ -57,7 +57,9 @@ def run():
|
||||
model = LlamaForCausalLM(llama_config)
|
||||
model = model.half()
|
||||
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if TP_SIZE > 1 else False, extra_kwargs={"inference_only": True}
|
||||
)
|
||||
|
||||
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||
batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
|
||||
|
@@ -36,7 +36,7 @@ def run(test_config):
|
||||
|
||||
# 1. check TPInferEngine init and model optimization
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
|
||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
|
||||
)
|
||||
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||
|
||||
|
@@ -13,11 +13,10 @@ from colossalai.shardformer import ShardConfig
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
try:
|
||||
import lightllm
|
||||
HAS_LIGHTLLM_KERNEL = True
|
||||
except:
|
||||
HAS_LIGHTLLM_KERNEL = False
|
||||
|
||||
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||
TPSIZE = 2
|
||||
BATCH_SIZE = 8
|
||||
@@ -43,7 +42,7 @@ def run_llama_test(test_config):
|
||||
model = model.half()
|
||||
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
|
||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
|
||||
)
|
||||
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
||||
@@ -63,7 +62,10 @@ def check_llama(rank, world_size, port):
|
||||
run_llama_test()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
|
||||
@pytest.mark.skipif(
|
||||
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
|
||||
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
|
||||
)
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
@@ -13,7 +13,6 @@ from colossalai.shardformer import ShardConfig
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
try:
|
||||
import lightllm
|
||||
HAS_LIGHTLLM_KERNEL = True
|
||||
except:
|
||||
HAS_LIGHTLLM_KERNEL = False
|
||||
@@ -41,7 +40,7 @@ def run_llama_test(test_config):
|
||||
model = model.half()
|
||||
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
|
||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
|
||||
)
|
||||
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
||||
@@ -61,7 +60,10 @@ def check_llama(rank, world_size, port):
|
||||
run_llama_test()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
|
||||
@pytest.mark.skipif(
|
||||
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
|
||||
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
|
||||
)
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
Reference in New Issue
Block a user