feat(model): Support Llama-3 (#1436)

This commit is contained in:
Fangyin Cheng
2024-04-20 14:07:09 +08:00
committed by GitHub
parent b49b07f011
commit 82e4ce4c43
7 changed files with 69 additions and 5 deletions

View File

@@ -20,6 +20,8 @@ def huggingface_chat_generate_stream(
top_p = float(params.get("top_p", 1.0))
echo = params.get("echo", False)
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_token_ids = params.get("stop_token_ids", [])
do_sample = params.get("do_sample", None)
input_ids = tokenizer(prompt).input_ids
# input_ids = input_ids.to(device)
@@ -39,13 +41,22 @@ def huggingface_chat_generate_stream(
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=not echo, skip_special_tokens=True
)
generate_kwargs = {
"input_ids": input_ids,
base_kwargs = {
"max_length": context_len,
"temperature": temperature,
"streamer": streamer,
"top_p": top_p,
}
if stop_token_ids:
base_kwargs["eos_token_id"] = stop_token_ids
if do_sample is not None:
base_kwargs["do_sample"] = do_sample
logger.info(f"Predict with parameters: {base_kwargs}")
generate_kwargs = {"input_ids": input_ids, **base_kwargs}
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
out = ""