mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 22:52:25 +00:00
* [Hot Fix] CI,import,requirements-test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -57,11 +57,11 @@ class LLMEngine(BaseEngine):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_or_path: nn.Module | str,
|
||||
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
|
||||
inference_config: InferenceConfig = None,
|
||||
verbose: bool = False,
|
||||
model_policy: Policy | type[Policy] = None,
|
||||
model_policy: Union[Policy, type[Policy]] = None,
|
||||
) -> None:
|
||||
self.inference_config = inference_config
|
||||
self.dtype = inference_config.dtype
|
||||
|
@@ -186,8 +186,6 @@ def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]):
|
||||
"""
|
||||
|
||||
try:
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
DiffusionPipeline.load_config(model_or_path)
|
||||
return ModelType.DIFFUSION_MODEL
|
||||
except:
|
||||
|
Reference in New Issue
Block a user