feat(model): Support Llama-3 (#1436)

Fangyin Cheng
2024-04-20 14:07:09 +08:00
committed by GitHub
parent b49b07f011
commit 82e4ce4c43
7 changed files with 69 additions and 5 deletions


@@ -270,9 +270,48 @@ class QwenAdapter(NewHFChatModelAdapter):
 )
 
 
+class Llama3Adapter(NewHFChatModelAdapter):
+    """
+    https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
+    https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct
+    """
+
+    support_4bit: bool = True
+    support_8bit: bool = True
+
+    def do_match(self, lower_model_name_or_path: Optional[str] = None):
+        return lower_model_name_or_path and "llama-3" in lower_model_name_or_path
+
+    def get_str_prompt(
+        self,
+        params: Dict,
+        messages: List[ModelMessage],
+        tokenizer: Any,
+        prompt_template: str = None,
+        convert_to_compatible_format: bool = False,
+    ) -> Optional[str]:
+        str_prompt = super().get_str_prompt(
+            params,
+            messages,
+            tokenizer,
+            prompt_template,
+            convert_to_compatible_format,
+        )
+        terminators = [
+            tokenizer.eos_token_id,
+            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+        ]
+        exist_token_ids = params.get("stop_token_ids", [])
+        terminators.extend(exist_token_ids)
+        # TODO(fangyinc): We should modify the params in the future
+        params["stop_token_ids"] = terminators
+        return str_prompt
+
+
 register_model_adapter(YiAdapter)
 register_model_adapter(Mixtral8x7BAdapter)
 register_model_adapter(SOLARAdapter)
 register_model_adapter(GemmaAdapter)
 register_model_adapter(StarlingLMAdapter)
 register_model_adapter(QwenAdapter)
+register_model_adapter(Llama3Adapter)
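
The point of the adapter above is the stop-token handling: Llama-3 instruct models mark the end of each assistant turn with the special <|eot_id|> token, which is not the tokenizer's default eos token, so the adapter appends its id to params["stop_token_ids"] before generation; without it the model tends to keep generating past the end of its reply. Below is a minimal sketch (not part of this commit) of how those terminator ids resolve, assuming the transformers library and local access to the gated meta-llama/Meta-Llama-3-8B-Instruct tokenizer:

# Sketch only: resolve the Llama-3 terminator ids the same way the adapter does.
# The model name is an example; any Llama-3 instruct tokenizer behaves the same.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

params = {"stop_token_ids": []}  # as received by get_str_prompt
terminators = [
    tokenizer.eos_token_id,                         # usually <|end_of_text|>
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),  # end-of-turn marker
]
terminators.extend(params.get("stop_token_ids", []))
params["stop_token_ids"] = terminators
print(params["stop_token_ids"])  # e.g. [128001, 128009]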


@@ -20,6 +20,8 @@ def huggingface_chat_generate_stream(
     top_p = float(params.get("top_p", 1.0))
     echo = params.get("echo", False)
     max_new_tokens = int(params.get("max_new_tokens", 2048))
+    stop_token_ids = params.get("stop_token_ids", [])
+    do_sample = params.get("do_sample", None)
 
     input_ids = tokenizer(prompt).input_ids
     # input_ids = input_ids.to(device)
@@ -39,13 +41,22 @@ def huggingface_chat_generate_stream(
     streamer = TextIteratorStreamer(
         tokenizer, skip_prompt=not echo, skip_special_tokens=True
     )
-    generate_kwargs = {
-        "input_ids": input_ids,
+    base_kwargs = {
         "max_length": context_len,
         "temperature": temperature,
         "streamer": streamer,
         "top_p": top_p,
     }
+
+    if stop_token_ids:
+        base_kwargs["eos_token_id"] = stop_token_ids
+    if do_sample is not None:
+        base_kwargs["do_sample"] = do_sample
+
+    logger.info(f"Predict with parameters: {base_kwargs}")
+
+    generate_kwargs = {"input_ids": input_ids, **base_kwargs}
+
     thread = Thread(target=model.generate, kwargs=generate_kwargs)
     thread.start()
     out = ""