Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 22:52:25 +00:00)
[inference] update examples and engine (#5073)

* update examples and engine
* fix choices
* update example
@@ -7,7 +7,7 @@ import transformers
 from packaging import version
 
 import colossalai
-from colossalai.inference import CaiInferEngine
+from colossalai.inference import InferenceEngine
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
@@ -36,7 +36,7 @@ def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
         transformers.BloomConfig(vocab_size=20000, hidden_size=512, n_head=4, n_layer=4)
     )
 
-    engine = CaiInferEngine(
+    engine = InferenceEngine(
         tp_size=tp_size,
         pp_size=pp_size,
         model=model,
@@ -6,7 +6,7 @@ import torch.distributed as dist
 from packaging import version
 
 import colossalai
-from colossalai.inference import CaiInferEngine
+from colossalai.inference import InferenceEngine
 from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@@ -44,7 +44,7 @@ def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
     )
     model = ChatGLMForConditionalGeneration(chatglm_config)
 
-    engine = CaiInferEngine(
+    engine = InferenceEngine(
         tp_size=tp_size,
         pp_size=pp_size,
         model=model,
@@ -7,7 +7,7 @@ import transformers
 from packaging import version
 
 import colossalai
-from colossalai.inference import CaiInferEngine
+from colossalai.inference import InferenceEngine
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
@@ -41,7 +41,7 @@ def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
         )
     )
 
-    engine = CaiInferEngine(
+    engine = InferenceEngine(
         tp_size=tp_size,
         pp_size=pp_size,
         model=model,
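For context, a minimal usage sketch of the renamed engine, not part of this commit: only the tp_size, pp_size, and model arguments are confirmed by the hunks above. The Bloom model class, the max_output_len and micro_batch_size arguments (borrowed from the test function's signature), and the generate() call are assumptions for illustration.

# Minimal sketch, assuming the constructor arguments visible in the diffs.
# In the real tests this runs under a spawned distributed setup (note the
# spawn import above); a colossalai launch step may be required first.
import torch
import transformers

from colossalai.inference import InferenceEngine

# Assumed model class; the hunks only show the BloomConfig construction.
model = transformers.BloomForCausalLM(
    transformers.BloomConfig(vocab_size=20000, hidden_size=512, n_head=4, n_layer=4)
)

engine = InferenceEngine(
    tp_size=1,             # parameter name confirmed by the diff
    pp_size=1,             # parameter name confirmed by the diff
    model=model,           # parameter name confirmed by the diff
    max_output_len=32,     # assumed from the test function's signature
    micro_batch_size=1,    # assumed from the test function's signature
)

# Assumed entry point for running inference with the constructed engine.
input_ids = torch.randint(0, 20000, (1, 16))
output = engine.generate(input_ids)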