From 66abf1c6e89860b55e2f26a847dd86f8fecfc863 Mon Sep 17 00:00:00 2001 From: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Date: Mon, 8 Jul 2024 22:32:06 +0800 Subject: [PATCH] [HotFix] CI,import,requirements-test for #5838 (#5892) * [Hot Fix] CI,import,requirements-test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/inference/core/llm_engine.py | 6 +++--- colossalai/inference/utils.py | 2 -- examples/inference/stable_diffusion/test_ci.sh | 2 ++ requirements/requirements-test.txt | 1 - 4 files changed, 5 insertions(+), 6 deletions(-) create mode 100644 examples/inference/stable_diffusion/test_ci.sh diff --git a/colossalai/inference/core/llm_engine.py b/colossalai/inference/core/llm_engine.py index b973d371d..1dbc3ace8 100644 --- a/colossalai/inference/core/llm_engine.py +++ b/colossalai/inference/core/llm_engine.py @@ -57,11 +57,11 @@ class LLMEngine(BaseEngine): def __init__( self, - model_or_path: nn.Module | str, - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None, + model_or_path: Union[nn.Module, str], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, inference_config: InferenceConfig = None, verbose: bool = False, - model_policy: Policy | type[Policy] = None, + model_policy: Union[Policy, type[Policy]] = None, ) -> None: self.inference_config = inference_config self.dtype = inference_config.dtype diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index f2a0fc037..d0851e362 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -186,8 +186,6 @@ def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]): """ try: - from diffusers import DiffusionPipeline - DiffusionPipeline.load_config(model_or_path) return ModelType.DIFFUSION_MODEL except: diff --git a/examples/inference/stable_diffusion/test_ci.sh b/examples/inference/stable_diffusion/test_ci.sh new file mode 100644 index 000000000..d0189431c --- /dev/null +++ b/examples/inference/stable_diffusion/test_ci.sh @@ -0,0 +1,2 @@ +#!/bin/bash +echo "Skip the test (this test is slow)" diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index e4affc7f5..93a3690fe 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -1,4 +1,3 @@ -diffusers pytest coverage==7.2.3 git+https://github.com/hpcaitech/pytest-testmon