feat(model): Support gemma-2 model (#1675)

Fangyin Cheng
2024-06-28 13:33:57 +08:00
committed by GitHub
parent cd2163e444
commit 374b6ad151
8 changed files with 108 additions and 8 deletions


@@ -112,6 +112,31 @@ class LLMModelAdapter(ABC):
"""Load model and tokenizer"""
raise NotImplementedError
def parse_max_length(self, model, tokenizer) -> Optional[int]:
"""Parse the max_length of the model.
Returns:
Optional[int]: The max_length of the model
"""
if not (tokenizer or model):
return None
try:
model_max_length = None
if tokenizer and hasattr(tokenizer, "model_max_length"):
model_max_length = tokenizer.model_max_length
if model_max_length and model_max_length < 100000000:
# Can't be too large
return model_max_length
if model and hasattr(model, "config"):
model_config = model.config
if hasattr(model_config, "max_sequence_length"):
return model_config.max_sequence_length
if hasattr(model_config, "max_position_embeddings"):
return model_config.max_position_embeddings
return None
except Exception:
return None
def load_from_params(self, params):
"""Load the model and tokenizer according to the given parameters"""
raise NotImplementedError
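
The sketch below (not part of the commit) illustrates how parse_max_length resolves a context window: it prefers the tokenizer's model_max_length unless that value is an implausibly large sentinel, then falls back to the model config's max_sequence_length or max_position_embeddings. The DummyTokenizer, DummyConfig, and DummyModel classes are hypothetical stand-ins for a Hugging Face tokenizer and model; the logic is copied standalone so it can run without the adapter class.

# Minimal sketch, assuming a gemma-2-style config that exposes
# max_position_embeddings. Dummy* names are hypothetical placeholders.
from dataclasses import dataclass
from typing import Optional


@dataclass
class DummyTokenizer:
    # HF tokenizers often encode "unknown limit" as a huge sentinel value,
    # which the < 100000000 check above filters out.
    model_max_length: int = 1000000000000000019884624838656


@dataclass
class DummyConfig:
    # Configs like gemma-2's carry the limit as max_position_embeddings.
    max_position_embeddings: int = 8192


@dataclass
class DummyModel:
    config: DummyConfig


def parse_max_length(model, tokenizer) -> Optional[int]:
    # Same resolution order as the adapter method in the diff above.
    if not (tokenizer or model):
        return None
    try:
        model_max_length = None
        if tokenizer and hasattr(tokenizer, "model_max_length"):
            model_max_length = tokenizer.model_max_length
        if model_max_length and model_max_length < 100000000:
            return model_max_length
        if model and hasattr(model, "config"):
            model_config = model.config
            if hasattr(model_config, "max_sequence_length"):
                return model_config.max_sequence_length
            if hasattr(model_config, "max_position_embeddings"):
                return model_config.max_position_embeddings
        return None
    except Exception:
        return None


if __name__ == "__main__":
    # The tokenizer's sentinel is rejected, so the config value wins: prints 8192.
    print(parse_max_length(DummyModel(DummyConfig()), DummyTokenizer()))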