Mirror of https://github.com/hpcaitech/ColossalAI.git
adapted to pad_context_forward
Commit fa4fbdbffb (parent 47e53eaa1c), committed by FrankLeeeee
@@ -51,6 +51,8 @@ class InferenceEngine:
        self.model_config = model.config
        self.device = torch.device("cuda")

        model = model.eval()

        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
            self.dtype = torch.float32
        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
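Note: the dtype handling touched by this first hunk accepts either a string ("fp32"/"fp16") or a torch dtype. A minimal standalone sketch of that normalization pattern, assuming a hypothetical helper name resolve_dtype and a ValueError fallback (neither is taken from the repository):

    import torch

    # Illustrative sketch only: normalize a dtype that may be given as a string
    # ("fp32" / "fp16") or as a torch.dtype, mirroring the branches in the hunk above.
    def resolve_dtype(cfg_dtype):
        if cfg_dtype == "fp32" or cfg_dtype == torch.float32:
            return torch.float32
        elif cfg_dtype == "fp16" or cfg_dtype == torch.float16:
            return torch.float16
        # the diff is truncated here; remaining cases are handled past the hunk boundary
        raise ValueError(f"unsupported dtype setting: {cfg_dtype}")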
@@ -85,12 +87,12 @@ class InferenceEngine:
        Verify the input config
        """
        if not isinstance(self.model, nn.Module):
-           raise TypeError(f"the model type must be nn.Module, but get {type(self.model)}")
+           raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
        if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
            self.tokenizer, PreTrainedTokenizer
        ):
            raise TypeError(
-               f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but get {type(self.tokenizer)}"
+               f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
            )
        assert (
            self.model.__class__.__name__ in _supported_models
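Note: this second hunk changes only the wording of the error messages ("get" → "got"); the validation logic is unchanged. A self-contained sketch of the same checks, written as a hypothetical free function verify_inputs rather than the engine's actual method:

    import torch.nn as nn
    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

    # Illustrative standalone version of the model/tokenizer checks shown above.
    def verify_inputs(model, tokenizer):
        if not isinstance(model, nn.Module):
            raise TypeError(f"the model type must be nn.Module, but got {type(model)}")
        if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
            raise TypeError(
                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(tokenizer)}"
            )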