feat(model): Support gemma-2 model (#1675)

Fangyin Cheng
2024-06-28 13:33:57 +08:00
committed by GitHub
parent cd2163e444
commit 374b6ad151
8 changed files with 108 additions and 8 deletions


@@ -73,6 +73,10 @@ class NewHFChatModelAdapter(LLMModelAdapter, ABC):
            ) from exc
        self.check_dependencies()
        logger.info(
            f"Load model from {model_path}, from_pretrained_kwargs: {from_pretrained_kwargs}"
        )
        revision = from_pretrained_kwargs.get("revision", "main")
        try:
            tokenizer = AutoTokenizer.from_pretrained(
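For reference, the revision handling added above simply falls back to the Hub's "main" branch before the tokenizer load. A minimal standalone sketch of the same pattern, assuming a placeholder model id (the real adapter forwards many more kwargs):

# Sketch of the revision-defaulting pattern above (illustrative only).
from transformers import AutoTokenizer

from_pretrained_kwargs = {}  # caller-supplied kwargs, possibly empty
revision = from_pretrained_kwargs.get("revision", "main")  # default to the "main" branch
tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-2-9b-it",  # placeholder model id (gated on the Hub; substitute any accessible repo)
    revision=revision,
)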
@@ -235,6 +239,43 @@ class GemmaAdapter(NewHFChatModelAdapter):
        )


class Gemma2Adapter(NewHFChatModelAdapter):
    """
    https://huggingface.co/google/gemma-2-27b-it
    https://huggingface.co/google/gemma-2-9b-it
    """

    support_4bit: bool = True
    support_8bit: bool = True
    support_system_message: bool = False

    def use_fast_tokenizer(self) -> bool:
        return True

    def check_transformer_version(self, current_version: str) -> None:
        if not current_version >= "4.42.1":
            raise ValueError(
                "Gemma2 requires transformers.__version__>=4.42.1, please upgrade your transformers package."
            )

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return (
            lower_model_name_or_path
            and "gemma-2-" in lower_model_name_or_path
            and "it" in lower_model_name_or_path
        )

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        import torch

        if not from_pretrained_kwargs:
            from_pretrained_kwargs = {}
        from_pretrained_kwargs["torch_dtype"] = torch.bfloat16
        # from_pretrained_kwargs["revision"] = "float16"
        model, tokenizer = super().load(model_path, from_pretrained_kwargs)
        return model, tokenizer


class StarlingLMAdapter(NewHFChatModelAdapter):
    """
    https://huggingface.co/Nexusflow/Starling-LM-7B-beta
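As a quick sanity check of the matching rule added above, here is a standalone re-implementation of the predicate; this is an illustration only, not the project's adapter lookup:

# Standalone re-implementation of the Gemma2Adapter.do_match rule (illustrative only).
from typing import Optional

def matches_gemma2_it(lower_model_name_or_path: Optional[str] = None) -> bool:
    return bool(
        lower_model_name_or_path
        and "gemma-2-" in lower_model_name_or_path
        and "it" in lower_model_name_or_path
    )

assert matches_gemma2_it("google/gemma-2-9b-it".lower())
assert matches_gemma2_it("google/gemma-2-27b-it".lower())
assert not matches_gemma2_it("google/gemma-2-9b".lower())  # base (non-instruct) checkpoint is not matched

The load() override additionally forces torch.bfloat16 before delegating to the base class.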
@@ -416,6 +457,17 @@ class DeepseekV2Adapter(NewHFChatModelAdapter):
        return model, tokenizer


class DeepseekCoderV2Adapter(DeepseekV2Adapter):
    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return (
            lower_model_name_or_path
            and "deepseek" in lower_model_name_or_path
            and "coder" in lower_model_name_or_path
            and "v2" in lower_model_name_or_path
            and "instruct" in lower_model_name_or_path
        )


class SailorAdapter(QwenAdapter):
    """
    https://huggingface.co/sail/Sailor-14B-Chat
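The DeepSeek-Coder-V2 predicate above follows the same substring style; a compact standalone illustration, using example repo ids from the Hub:

# Standalone illustration of the DeepseekCoderV2Adapter.do_match rule (not the project's lookup).
def matches_deepseek_coder_v2_instruct(name: str) -> bool:
    lowered = name.lower()
    return all(token in lowered for token in ("deepseek", "coder", "v2", "instruct"))

assert matches_deepseek_coder_v2_instruct("deepseek-ai/DeepSeek-Coder-V2-Instruct")
assert not matches_deepseek_coder_v2_instruct("deepseek-ai/DeepSeek-V2-Chat")  # lacks "coder" and "instruct"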
@@ -520,11 +572,13 @@ register_model_adapter(Yi15Adapter)
register_model_adapter(Mixtral8x7BAdapter)
register_model_adapter(SOLARAdapter)
register_model_adapter(GemmaAdapter)
register_model_adapter(Gemma2Adapter)
register_model_adapter(StarlingLMAdapter)
register_model_adapter(QwenAdapter)
register_model_adapter(QwenMoeAdapter)
register_model_adapter(Llama3Adapter)
register_model_adapter(DeepseekV2Adapter)
register_model_adapter(DeepseekCoderV2Adapter)
register_model_adapter(SailorAdapter)
register_model_adapter(PhiAdapter)
register_model_adapter(SQLCoderAdapter)