Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-09 04:49:26 +00:00
feat(model): Support Llama-3 (#1436)
@@ -20,6 +20,8 @@ def huggingface_chat_generate_stream(
     top_p = float(params.get("top_p", 1.0))
     echo = params.get("echo", False)
     max_new_tokens = int(params.get("max_new_tokens", 2048))
+    stop_token_ids = params.get("stop_token_ids", [])
+    do_sample = params.get("do_sample", None)

     input_ids = tokenizer(prompt).input_ids
     # input_ids = input_ids.to(device)
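The two parameters added in this hunk are what Llama-3 support hinges on: Llama-3 ends assistant turns with <|eot_id|> rather than only <|end_of_text|>, so generation has to honor a list of stop token ids instead of a single eos id. A minimal sketch of the params dict this function might receive; the concrete token ids below are the published Meta-Llama-3 values and should be treated as an assumption to verify against your tokenizer:

    # Hypothetical request params for a Llama-3 chat completion.
    # 128001 = <|end_of_text|>, 128009 = <|eot_id|> -- confirm with
    # tokenizer.convert_tokens_to_ids if your checkpoint differs.
    params = {
        "prompt": "<|begin_of_text|>...",  # fully rendered chat template
        "temperature": 0.7,
        "top_p": 1.0,
        "max_new_tokens": 2048,
        "stop_token_ids": [128001, 128009],
        "do_sample": True,
        "echo": False,
    }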
@@ -39,13 +41,22 @@ def huggingface_chat_generate_stream(
     streamer = TextIteratorStreamer(
         tokenizer, skip_prompt=not echo, skip_special_tokens=True
     )
-    generate_kwargs = {
-        "input_ids": input_ids,
+
+    base_kwargs = {
         "max_length": context_len,
         "temperature": temperature,
         "streamer": streamer,
         "top_p": top_p,
     }
+
+    if stop_token_ids:
+        base_kwargs["eos_token_id"] = stop_token_ids
+    if do_sample is not None:
+        base_kwargs["do_sample"] = do_sample
+
+    logger.info(f"Predict with parameters: {base_kwargs}")
+
+    generate_kwargs = {"input_ids": input_ids, **base_kwargs}
     thread = Thread(target=model.generate, kwargs=generate_kwargs)
     thread.start()
     out = ""
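For context, the pattern this hunk extends is the standard transformers streaming idiom: model.generate blocks until generation finishes, so it runs on a background thread while the caller iterates the TextIteratorStreamer for incremental text. Recent transformers releases accept a list of ints for eos_token_id, which is why stop_token_ids can be passed through unchanged. A self-contained sketch of the idiom; the model name and prompt are illustrative assumptions, not taken from the commit:

    # Minimal standalone sketch of the streaming pattern used above.
    # "gpt2" and the prompt are placeholders chosen for a quick local run.
    from threading import Thread

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TextIteratorStreamer,
    )

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # model.generate blocks until generation is done, so run it off the
    # main thread; the streamer then yields decoded text chunks as the
    # model produces tokens.
    generate_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": 32,
        "do_sample": False,
        "streamer": streamer,
    }
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    out = ""
    for new_text in streamer:
        out += new_text
        print(new_text, end="", flush=True)
    thread.join()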