[inference] update examples and engine (#5073)

* update examples and engine

* fix choices

* update example
Author: Xu Kai
Date: 2023-11-20 19:44:52 +08:00
Committed by: GitHub
Parent: 0c7d8bebd5
Commit: fb103cfd6e
12 changed files with 107 additions and 273 deletions
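
Across the test files below, the substantive change is a rename of the public entry point: CaiInferEngine becomes InferenceEngine. A minimal before/after sketch of what callers must change (the constructor arguments themselves are untouched by the rename; the placeholder arguments here mirror those in the tests):

    # before
    from colossalai.inference import CaiInferEngine
    engine = CaiInferEngine(tp_size=1, pp_size=1, model=model)

    # after
    from colossalai.inference import InferenceEngine
    engine = InferenceEngine(tp_size=1, pp_size=1, model=model)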


@@ -7,7 +7,7 @@ import transformers
 from packaging import version
 
 import colossalai
-from colossalai.inference import CaiInferEngine
+from colossalai.inference import InferenceEngine
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
@@ -36,7 +36,7 @@ def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
         transformers.BloomConfig(vocab_size=20000, hidden_size=512, n_head=4, n_layer=4)
     )
 
-    engine = CaiInferEngine(
+    engine = InferenceEngine(
         tp_size=tp_size,
         pp_size=pp_size,
         model=model,
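
For context, a sketch of how the renamed engine is assembled in the Bloom test above. The model config values are taken verbatim from the diff; treating micro_batch_size and max_output_len (which appear in the test's signature) as engine keyword arguments is an assumption, as is the need to initialize distributed state via colossalai.launch beforehand.

    import transformers
    from colossalai.inference import InferenceEngine

    # Tiny Bloom model matching the config used in the test above.
    model = transformers.BloomForCausalLM(
        transformers.BloomConfig(vocab_size=20000, hidden_size=512, n_head=4, n_layer=4)
    )

    # Assumption: micro_batch_size and max_output_len come from the test's
    # signature and are forwarded to the engine; only tp_size, pp_size, and
    # model are confirmed by the diff. Distributed state (colossalai.launch)
    # is assumed to be initialized before this point.
    engine = InferenceEngine(
        tp_size=1,
        pp_size=1,
        model=model,
        micro_batch_size=1,
        max_output_len=32,
    )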


@@ -6,7 +6,7 @@ import torch.distributed as dist
 from packaging import version
 
 import colossalai
-from colossalai.inference import CaiInferEngine
+from colossalai.inference import InferenceEngine
 from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@@ -44,7 +44,7 @@ def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
     )
     model = ChatGLMForConditionalGeneration(chatglm_config)
 
-    engine = CaiInferEngine(
+    engine = InferenceEngine(
         tp_size=tp_size,
         pp_size=pp_size,
         model=model,


@@ -7,7 +7,7 @@ import transformers
 from packaging import version
 
 import colossalai
-from colossalai.inference import CaiInferEngine
+from colossalai.inference import InferenceEngine
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
@@ -41,7 +41,7 @@ def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
         )
     )
 
-    engine = CaiInferEngine(
+    engine = InferenceEngine(
         tp_size=tp_size,
         pp_size=pp_size,
         model=model,