feat(model): Support gemma-2 model (#1675)
@@ -112,6 +112,31 @@ class LLMModelAdapter(ABC):
         """Load model and tokenizer"""
         raise NotImplementedError
 
+    def parse_max_length(self, model, tokenizer) -> Optional[int]:
+        """Parse the max_length of the model.
+
+        Returns:
+            Optional[int]: The max_length of the model
+        """
+        if not (tokenizer or model):
+            return None
+        try:
+            model_max_length = None
+            if tokenizer and hasattr(tokenizer, "model_max_length"):
+                model_max_length = tokenizer.model_max_length
+            if model_max_length and model_max_length < 100000000:
+                # Can't be too large
+                return model_max_length
+            if model and hasattr(model, "config"):
+                model_config = model.config
+                if hasattr(model_config, "max_sequence_length"):
+                    return model_config.max_sequence_length
+                if hasattr(model_config, "max_position_embeddings"):
+                    return model_config.max_position_embeddings
+            return None
+        except Exception:
+            return None
+
     def load_from_params(self, params):
         """Load the model and tokenizer according to the given parameters"""
         raise NotImplementedError