feat(model): Support Llama-3 (#1436)

Author: Fangyin Cheng
Date: 2024-04-20 14:07:09 +08:00
Committed by: GitHub
Parent: b49b07f011
Commit: 82e4ce4c43
7 changed files with 69 additions and 5 deletions


@@ -270,9 +270,48 @@ class QwenAdapter(NewHFChatModelAdapter):
        )


class Llama3Adapter(NewHFChatModelAdapter):
    """
    https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
    https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct
    """

    support_4bit: bool = True
    support_8bit: bool = True

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path and "llama-3" in lower_model_name_or_path

    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: str = None,
        convert_to_compatible_format: bool = False,
    ) -> Optional[str]:
        str_prompt = super().get_str_prompt(
            params,
            messages,
            tokenizer,
            prompt_template,
            convert_to_compatible_format,
        )
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        exist_token_ids = params.get("stop_token_ids", [])
        terminators.extend(exist_token_ids)
        # TODO(fangyinc): We should modify the params in the future
        params["stop_token_ids"] = terminators
        return str_prompt


register_model_adapter(YiAdapter)
register_model_adapter(Mixtral8x7BAdapter)
register_model_adapter(SOLARAdapter)
register_model_adapter(GemmaAdapter)
register_model_adapter(StarlingLMAdapter)
register_model_adapter(QwenAdapter)
register_model_adapter(Llama3Adapter)
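
For context on what the new adapter changes at generation time, the sketch below replicates the stop-token merging done in Llama3Adapter.get_str_prompt in isolation. The FakeLlama3Tokenizer stub, its token IDs, and the merge_stop_token_ids helper are illustrative assumptions, not part of this commit; in DB-GPT the real Hugging Face tokenizer for Meta-Llama-3 is supplied by the model loader.

# Minimal, runnable sketch (assumptions noted inline), not part of this commit.
from typing import Dict


class FakeLlama3Tokenizer:
    """Stub exposing only the two members the adapter touches."""

    # Assumed IDs for illustration: the Meta-Llama-3 tokenizer maps
    # <|end_of_text|> to 128001 and <|eot_id|> to 128009.
    eos_token_id = 128001

    def convert_tokens_to_ids(self, token: str) -> int:
        return 128009 if token == "<|eot_id|>" else -1


def merge_stop_token_ids(params: Dict, tokenizer) -> Dict:
    """Mirror of the terminator handling added in Llama3Adapter."""
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    # Keep any stop token ids the caller already set, then write the merged
    # list back into params, as the adapter does.
    terminators.extend(params.get("stop_token_ids", []))
    params["stop_token_ids"] = terminators
    return params


if __name__ == "__main__":
    params = {"stop_token_ids": [2]}
    merge_stop_token_ids(params, FakeLlama3Tokenizer())
    print(params)  # {'stop_token_ids': [128001, 128009, 2]}

The point of the merge is that Llama-3 instruct models end an assistant turn with <|eot_id|> rather than only the EOS token, so generation must stop on either id; the adapter folds both into params["stop_token_ids"] before inference.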