adapted to pad_context_forward

This commit is contained in:
yuehuayingxueluo
2024-01-09 13:52:53 +08:00
committed by FrankLeeeee
parent 47e53eaa1c
commit fa4fbdbffb
9 changed files with 42 additions and 41 deletions

View File

@@ -51,6 +51,8 @@ class InferenceEngine:
self.model_config = model.config
self.device = torch.device("cuda")
model = model.eval()
if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
self.dtype = torch.float32
elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
@@ -85,12 +87,12 @@ class InferenceEngine:
Verify the input config
"""
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but get {type(self.model)}")
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
self.tokenizer, PreTrainedTokenizer
):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but get {type(self.tokenizer)}"
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
assert (
self.model.__class__.__name__ in _supported_models