Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-06 19:40:13 +00:00)
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
0 dbgpt/model/llm_out/__init__.py Normal file
103 dbgpt/model/llm_out/chatglm_llm.py Normal file
@@ -0,0 +1,103 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

from typing import List
import re

import torch

from dbgpt.app.scene import ModelMessage, _parse_model_messages

# TODO move sep to scene prompt of model
_CHATGLM_SEP = "\n"
_CHATGLM2_SEP = "\n\n"


@torch.inference_mode()
def chatglm_generate_stream(
    model, tokenizer, params, device, context_len=2048, stream_interval=2
):
    """Generate text using chatglm model's chat api_v1"""
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    top_p = float(params.get("top_p", 1.0))
    stop = params.get("stop", "###")
    echo = params.get("echo", False)

    generate_kwargs = {
        "do_sample": True if temperature > 1e-5 else False,
        "top_p": top_p,
        "repetition_penalty": 1.0,
        "logits_processor": None,
    }

    if temperature > 1e-5:
        generate_kwargs["temperature"] = temperature

    # TODO, Fix this
    # print(prompt)
    # messages = prompt.split(stop)
    messages: List[ModelMessage] = params["messages"]
    query, system_messages, hist = _parse_model_messages(messages)
    system_messages_str = "".join(system_messages)
    if not hist:
        # No history conversation, but has system messages, merge to user`s query
        query = prompt_adaptation(system_messages_str, query)
    else:
        # history exist, add system message to head of history
        hist[0][0] = system_messages_str + _CHATGLM2_SEP + hist[0][0]

    print("Query Message: ", query)
    print("hist: ", hist)

    for i, (response, new_hist) in enumerate(
        model.stream_chat(tokenizer, query, hist, **generate_kwargs)
    ):
        if echo:
            output = query + " " + response
        else:
            output = response

        yield output

    yield output


class HistoryEntry:
    def __init__(self, question: str = "", answer: str = ""):
        self.question = question
        self.answer = answer

    def add_question(self, question: str):
        self.question += question

    def add_answer(self, answer: str):
        self.answer += answer

    def to_list(self):
        if self.question == "" or self.answer == "":
            return None
        return [self.question, self.answer]


def build_history(hist: List[HistoryEntry]) -> List[List[str]]:
    return list(filter(lambda hl: hl is not None, map(lambda h: h.to_list(), hist)))


def prompt_adaptation(system_messages_str: str, human_message: str) -> str:
    if not system_messages_str or system_messages_str == "":
        return human_message
    # TODO Multi-model prompt adaptation
    adaptation_rules = [
        r"Question:\s*{}\s*",  # chat_db scene
        r"Goals:\s*{}\s*",  # chat_execution
        r"问题:\s*{}\s*",  # chat_knowledge zh
        r"question:\s*{}\s*",  # chat_knowledge en
    ]
    # system message has include human question
    for rule in adaptation_rules:
        pattern = re.compile(rule.format(re.escape(human_message)))
        if re.search(pattern, system_messages_str):
            return system_messages_str
    # https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
    return system_messages_str + _CHATGLM2_SEP + human_message
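A minimal sketch of how the prompt-side helpers above behave, assuming only that the dbgpt package is importable; the message strings are invented for illustration:

# Sketch: exercising the pure-Python helpers from chatglm_llm.py.
# Assumes dbgpt is installed; the example strings are made up.
from dbgpt.model.llm_out.chatglm_llm import (
    HistoryEntry,
    build_history,
    prompt_adaptation,
)

# prompt_adaptation keeps the system prompt unchanged when it already embeds
# the user question (here via the "question:" rule); otherwise it appends the
# question after the ChatGLM2 separator "\n\n".
system = "You answer from context.\nquestion: What is DB-GPT?\n"
print(prompt_adaptation(system, "What is DB-GPT?"))       # system prompt unchanged
print(prompt_adaptation("Be terse.", "What is DB-GPT?"))  # "Be terse.\n\nWhat is DB-GPT?"

# build_history drops incomplete question/answer pairs.
h1 = HistoryEntry("Hi", "Hello!")
h2 = HistoryEntry("Unanswered question")  # no answer -> filtered out
print(build_history([h1, h2]))            # [["Hi", "Hello!"]]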
53 dbgpt/model/llm_out/falcon_llm.py Normal file
@@ -0,0 +1,53 @@
import torch
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria


def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
    tokenizer.bos_token_id = 1
    print(params)
    stop = params.get("stop", "###")
    prompt = params["prompt"]
    query = prompt
    print("Query Message: ", query)

    input_ids = tokenizer(query, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    tokenizer.bos_token_id = 1
    stop_token_ids = [0]

    class StopOnTokens(StoppingCriteria):
        def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
        ) -> bool:
            for stop_id in stop_token_ids:
                if input_ids[0][-1] == stop_id:
                    return True
            return False

    stop = StopOnTokens()

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=512,
        temperature=1.0,
        do_sample=True,
        top_k=1,
        streamer=streamer,
        repetition_penalty=1.7,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    out = ""
    for new_text in streamer:
        out += new_text
        yield out
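For reference, a self-contained sketch of the same streamer-plus-background-thread pattern used above, assuming a small Hugging Face model such as "gpt2" purely for illustration (the falcon weights are not required to see the mechanism):

# Sketch of the TextIteratorStreamer pattern: model.generate runs in a worker
# thread and pushes decoded text into the streamer, which the caller iterates.
# "gpt2" is an assumed stand-in; any transformers causal LM would do.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, max_new_tokens=32, streamer=streamer),
).start()

text = ""
for new_text in streamer:  # yields decoded chunks as generation progresses
    text += new_text
print(text)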
62 dbgpt/model/llm_out/gorilla_llm.py Normal file
@@ -0,0 +1,62 @@
import torch


@torch.inference_mode()
def generate_stream(
    model, tokenizer, params, device, context_len=42048, stream_interval=2
):
    """Fork from https://github.com/ShishirPatil/gorilla/blob/main/inference/serve/gorilla_cli.py"""
    prompt = params["prompt"]
    l_prompt = len(prompt)
    max_new_tokens = int(params.get("max_new_tokens", 1024))
    stop_str = params.get("stop", None)

    input_ids = tokenizer(prompt).input_ids
    output_ids = list(input_ids)
    input_echo_len = len(input_ids)
    max_src_len = context_len - max_new_tokens - 8
    input_ids = input_ids[-max_src_len:]
    past_key_values = out = None

    for i in range(max_new_tokens):
        if i == 0:
            out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                past_key_values=past_key_values,
            )
            logits = out.logits
            past_key_values = out.past_key_values

        last_token_logits = logits[0][-1]

        probs = torch.softmax(last_token_logits, dim=-1)
        token = int(torch.multinomial(probs, num_samples=1))
        output_ids.append(token)

        if token == tokenizer.eos_token_id:
            stopped = True
        else:
            stopped = False

        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            tmp_output_ids = output_ids[input_echo_len:]
            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
            )
            pos = output.rfind(stop_str, l_prompt)
            if pos != -1:
                output = output[:pos]
                stopped = True
            yield output

        if stopped:
            break

    del past_key_values
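A hedged sketch of how a model worker might drive this generator; the checkpoint id, params dict, and device selection below are assumptions for illustration, not part of this commit:

# Sketch only: invoking gorilla_llm.generate_stream with a hand-built params dict.
# The repo id is illustrative; downloading and running it requires real hardware.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from dbgpt.model.llm_out.gorilla_llm import generate_stream

model_path = "gorilla-llm/gorilla-7b-hf-delta-v0"  # assumed checkpoint id
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)

params = {
    "prompt": "I would like to translate 'Hello' to French.",
    "max_new_tokens": 256,
    "stop": "###",  # required: the loop calls output.rfind(stop_str, ...)
}
for partial in generate_stream(model, tokenizer, params, device):
    print(partial, end="\r")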
10 dbgpt/model/llm_out/gpt4all_llm.py Normal file
@@ -0,0 +1,10 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-


def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
    stop = params.get("stop", "###")
    prompt = params["prompt"]
    role, query = prompt.split(stop)[0].split(":")
    print(f"gpt4all, role: {role}, query: {query}")
    yield model.generate(prompt=query, streaming=True)
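The prompt parsing above expects a single "role: text" segment terminated by the stop marker; a tiny illustration of that split, with a made-up prompt:

# Illustration of the prompt format gpt4all_generate_stream expects.
stop = "###"
prompt = "user: What is a vector database?###assistant:"

role, query = prompt.split(stop)[0].split(":")
print(role)   # "user"
print(query)  # " What is a vector database?"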
109 dbgpt/model/llm_out/guanaco_llm.py Normal file
@@ -0,0 +1,109 @@
import torch
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria


def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""

    print(params)
    stop = params.get("stop", "###")
    prompt = params["prompt"]
    query = prompt
    print("Query Message: ", query)

    input_ids = tokenizer(query, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    stop_token_ids = [0]

    class StopOnTokens(StoppingCriteria):
        def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
        ) -> bool:
            for stop_id in stop_token_ids:
                if input_ids[0][-1] == stop_id:
                    return True
            return False

    stop = StopOnTokens()

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=512,
        temperature=1.0,
        do_sample=True,
        top_k=1,
        streamer=streamer,
        repetition_penalty=1.7,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    t1 = Thread(target=model.generate, kwargs=generate_kwargs)
    t1.start()

    generator = model.generate(**generate_kwargs)
    for output in generator:
        # new_tokens = len(output) - len(input_ids[0])
        decoded_output = tokenizer.decode(output)
        if output[-1] in [tokenizer.eos_token_id]:
            break

        out = decoded_output.split("### Response:")[-1].strip()

        yield out


def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
    tokenizer.bos_token_id = 1
    print(params)
    stop = params.get("stop", "###")
    prompt = params["prompt"]
    max_new_tokens = params.get("max_new_tokens", 512)
    temerature = params.get("temperature", 1.0)

    query = prompt
    print("Query Message: ", query)

    input_ids = tokenizer(query, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    tokenizer.bos_token_id = 1
    stop_token_ids = [0]

    class StopOnTokens(StoppingCriteria):
        def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
        ) -> bool:
            for stop_id in stop_token_ids:
                if input_ids[-1][-1] == stop_id:
                    return True
            return False

    stop = StopOnTokens()

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temerature,
        do_sample=True,
        top_k=1,
        streamer=streamer,
        repetition_penalty=1.7,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    model.generate(**generate_kwargs)

    out = ""
    for new_text in streamer:
        out += new_text
        yield out
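The answer text is recovered from the decoded sequence by splitting on the "### Response:" marker; a small standalone illustration with a fabricated decode result:

# Illustration of the "### Response:" extraction used above; the decoded
# string is fabricated for the example.
decoded_output = (
    "### Instruction: explain guanaco briefly\n"
    "### Response: Guanaco is a LLaMA model fine-tuned with QLoRA."
)
out = decoded_output.split("### Response:")[-1].strip()
print(out)  # "Guanaco is a LLaMA model fine-tuned with QLoRA."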
53 dbgpt/model/llm_out/hf_chat_llm.py Normal file
@@ -0,0 +1,53 @@
import logging
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

logger = logging.getLogger(__name__)


@torch.inference_mode()
def huggingface_chat_generate_stream(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    params,
    device,
    context_len=4096,
):
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 0.7))
    top_p = float(params.get("top_p", 1.0))
    echo = params.get("echo", False)
    max_new_tokens = int(params.get("max_new_tokens", 2048))

    input_ids = tokenizer(prompt).input_ids
    # input_ids = input_ids.to(device)
    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:  # truncate
        max_src_len = context_len - max_new_tokens - 1
    input_ids = input_ids[-max_src_len:]
    input_echo_len = len(input_ids)
    input_ids = torch.as_tensor([input_ids], device=device)

    # messages = params["messages"]
    # messages = ModelMessage.to_openai_messages(messages)
    # input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')
    # input_ids = input_ids.to(device)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=not echo, skip_special_tokens=True
    )
    generate_kwargs = {
        "input_ids": input_ids,
        "max_length": context_len,
        "temperature": temperature,
        "streamer": streamer,
    }

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    out = ""
    for new_text in streamer:
        out += new_text
        yield out
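A minimal way to exercise this streaming helper end to end, assuming a small stand-in model ("gpt2") rather than a chat-tuned checkpoint; the params values are illustrative:

# Sketch: driving huggingface_chat_generate_stream with a small stand-in model.
# "gpt2", the prompt, and the params values are assumptions for illustration.
from transformers import AutoModelForCausalLM, AutoTokenizer

from dbgpt.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

params = {"prompt": "Databases are", "temperature": 0.7, "max_new_tokens": 32}
last = ""
for last in huggingface_chat_generate_stream(
    model, tokenizer, params, device="cpu", context_len=1024
):
    pass  # each iteration yields the text accumulated so far
print(last)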
8 dbgpt/model/llm_out/llama_cpp_llm.py Normal file
@@ -0,0 +1,8 @@
from typing import Dict
import torch


@torch.inference_mode()
def generate_stream(model, tokenizer, params: Dict, device: str, context_len: int):
    # Just support LlamaCppModel
    return model.generate_streaming(params=params, context_len=context_len)
38 dbgpt/model/llm_out/proxy_llm.py Normal file
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import time

from dbgpt.model.proxy.llms.chatgpt import chatgpt_generate_stream
from dbgpt.model.proxy.llms.bard import bard_generate_stream
from dbgpt.model.proxy.llms.claude import claude_generate_stream
from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream
from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream
from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream
from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream
from dbgpt.model.proxy.llms.spark import spark_generate_stream
from dbgpt.model.proxy.llms.proxy_model import ProxyModel


def proxyllm_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    generator_mapping = {
        "proxyllm": chatgpt_generate_stream,
        "chatgpt_proxyllm": chatgpt_generate_stream,
        "bard_proxyllm": bard_generate_stream,
        "claude_proxyllm": claude_generate_stream,
        # "gpt4_proxyllm": gpt4_generate_stream, move to chatgpt_generate_stream
        "wenxin_proxyllm": wenxin_generate_stream,
        "tongyi_proxyllm": tongyi_generate_stream,
        "zhipu_proxyllm": zhipu_generate_stream,
        "bc_proxyllm": baichuan_generate_stream,
        "spark_proxyllm": spark_generate_stream,
    }
    model_params = model.get_params()
    model_name = model_params.model_name
    default_error_message = f"{model_name} LLM is not supported"
    generator_function = generator_mapping.get(
        model_name, lambda: default_error_message
    )

    yield from generator_function(model, tokenizer, params, device, context_len)
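The function above is a plain dispatch table keyed by model name. The same idiom in isolation, with dummy generators standing in for the real proxy backends, looks like this:

# Standalone illustration of the dispatch-table idiom used by
# proxyllm_generate_stream; the backends here are dummy generators.
def chatgpt_stream(params):
    yield "chatgpt: " + params["prompt"]


def bard_stream(params):
    yield "bard: " + params["prompt"]


GENERATORS = {
    "chatgpt_proxyllm": chatgpt_stream,
    "bard_proxyllm": bard_stream,
}


def dispatch(model_name, params):
    default = lambda _params: iter([f"{model_name} LLM is not supported"])
    yield from GENERATORS.get(model_name, default)(params)


print(list(dispatch("bard_proxyllm", {"prompt": "hi"})))     # ['bard: hi']
print(list(dispatch("unknown_proxyllm", {"prompt": "hi"})))  # ['unknown_proxyllm LLM is not supported']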
233 dbgpt/model/llm_out/vicuna_base_llm.py Normal file
@@ -0,0 +1,233 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch


@torch.inference_mode()
def generate_stream(
    model, tokenizer, params, device, context_len=4096, stream_interval=2
):
    """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""
    prompt = params["prompt"]
    l_prompt = len(prompt)
    prompt = prompt.replace("ai:", "assistant:").replace("human:", "user:")
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    stop_str = params.get("stop", None)
    input_ids = tokenizer(prompt).input_ids
    output_ids = list(input_ids)

    max_src_len = context_len - max_new_tokens - 8
    input_ids = input_ids[-max_src_len:]

    for i in range(max_new_tokens):
        if i == 0:
            out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            attention_mask = torch.ones(
                1, past_key_values[0][0].shape[-2] + 1, device=device
            )
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
            logits = out.logits
            past_key_values = out.past_key_values

        last_token_logits = logits[0][-1]

        if device == "mps":
            # Switch to CPU by avoiding some bugs in mps backend.
            last_token_logits = last_token_logits.float().to("cpu")

        if temperature < 1e-4:
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits / temperature, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))

        output_ids.append(token)

        if token == tokenizer.eos_token_id:
            stopped = True
        else:
            stopped = False

        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            output = tokenizer.decode(output_ids, skip_special_tokens=True)
            pos = output.rfind(stop_str, l_prompt)
            if pos != -1:
                output = output[:pos]
                stopped = True
            yield output

        if stopped:
            break

    del past_key_values


@torch.inference_mode()
def generate_output(
    model, tokenizer, params, device, context_len=4096, stream_interval=2
):
    """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""

    prompt = params["prompt"]
    l_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    stop_str = params.get("stop", None)

    input_ids = tokenizer(prompt).input_ids
    output_ids = list(input_ids)

    max_src_len = context_len - max_new_tokens - 8
    input_ids = input_ids[-max_src_len:]

    for i in range(max_new_tokens):
        if i == 0:
            out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            attention_mask = torch.ones(
                1, past_key_values[0][0].shape[-2] + 1, device=device
            )
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
            logits = out.logits
            past_key_values = out.past_key_values

        last_token_logits = logits[0][-1]

        if device == "mps":
            # Switch to CPU by avoiding some bugs in mps backend.
            last_token_logits = last_token_logits.float().to("cpu")

        if temperature < 1e-4:
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits / temperature, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))

        output_ids.append(token)

        if token == tokenizer.eos_token_id:
            stopped = True
        else:
            stopped = False

        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            output = tokenizer.decode(output_ids, skip_special_tokens=True)
            pos = output.rfind(stop_str, l_prompt)
            if pos != -1:
                output = output[:pos]
                stopped = True
            return output

        if stopped:
            break
    del past_key_values


@torch.inference_mode()
def generate_output_ex(
    model, tokenizer, params, device, context_len=2048, stream_interval=2
):
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    stop_parameter = params.get("stop", None)

    if stop_parameter == tokenizer.eos_token:
        stop_parameter = None
    stop_strings = []
    if isinstance(stop_parameter, str):
        stop_strings.append(stop_parameter)
    elif isinstance(stop_parameter, list):
        stop_strings = stop_parameter
    elif stop_parameter is None:
        pass
    else:
        raise TypeError("Stop parameter must be string or list of strings.")

    input_ids = tokenizer(prompt).input_ids
    output_ids = []

    max_src_len = context_len - max_new_tokens - 8
    input_ids = input_ids[-max_src_len:]
    stop_word = None

    for i in range(max_new_tokens):
        if i == 0:
            out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                past_key_values=past_key_values,
            )
            logits = out.logits
            past_key_values = out.past_key_values

        last_token_logits = logits[0][-1]

        if temperature < 1e-4:
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits / temperature, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))

        output_ids.append(token)

        if token == tokenizer.eos_token_id:
            stopped = True
        else:
            stopped = False

        output = tokenizer.decode(output_ids, skip_special_tokens=True)
        # print("Partial output:", output)
        for stop_str in stop_strings:
            # print(f"Looking for '{stop_str}' in '{output[:l_prompt]}'#END")
            pos = output.rfind(stop_str)
            if pos != -1:
                # print("Found stop str: ", output)
                output = output[:pos]
                # print("Trimmed output: ", output)
                stopped = True
                stop_word = stop_str
                break
            else:
                pass
                # print("Not found")

        if stopped:
            break

    del past_key_values
    if pos != -1:
        return output[:pos]
    return output


@torch.inference_mode()
def get_embeddings(model, tokenizer, prompt):
    input_ids = tokenizer(prompt).input_ids
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_embeddings = model.get_input_embeddings().to(device)

    embeddings = input_embeddings(torch.LongTensor([input_ids]).to(device))
    mean = torch.mean(embeddings[0], 0).cpu().detach()
    return mean.to(device)
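get_embeddings above only touches the model's input embedding table, so it can be tried with any small checkpoint; a hedged sketch using "gpt2" as an assumed stand-in (the vicuna weights are not required):

# Sketch: mean input-embedding of a prompt via get_embeddings, using a small
# stand-in checkpoint ("gpt2") purely for illustration.
from transformers import AutoModelForCausalLM, AutoTokenizer

from dbgpt.model.llm_out.vicuna_base_llm import get_embeddings

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

vec = get_embeddings(model, tokenizer, "vector databases store embeddings")
print(vec.shape)  # torch.Size([768]) for gpt2: one mean vector over the prompt tokens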
93 dbgpt/model/llm_out/vicuna_llm.py Normal file
@@ -0,0 +1,93 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import json
from typing import Any, List, Mapping, Optional
from urllib.parse import urljoin

import requests
from langchain.embeddings.base import Embeddings
from langchain.llms.base import LLM
from dbgpt._private.pydantic import BaseModel

from dbgpt._private.config import Config

CFG = Config()


class VicunaLLM(LLM):
    vicuna_generate_path = "generate_stream"

    def _call(
        self,
        prompt: str,
        temperature: float,
        max_new_tokens: int,
        stop: Optional[List[str]] = None,
    ) -> str:
        params = {
            "prompt": prompt,
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "stop": stop,
        }
        response = requests.post(
            url=urljoin(CFG.MODEL_SERVER, self.vicuna_generate_path),
            data=json.dumps(params),
        )

        skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][skip_echo_len:].strip()
                    yield output

    @property
    def _llm_type(self) -> str:
        return "custome"

    def _identifying_params(self) -> Mapping[str, Any]:
        return {}


class VicunaEmbeddingLLM(BaseModel, Embeddings):
    vicuna_embedding_path = "embedding"

    def _call(self, prompt: str) -> str:
        p = prompt.strip()
        print("Sending prompt ", p)

        response = requests.post(
            url=urljoin(CFG.MODEL_SERVER, self.vicuna_embedding_path),
            json={"prompt": p},
        )
        response.raise_for_status()
        return response.json()["response"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Vicuna's server embedding endpoint for embedding search docs.

        Args:
            texts: The list of text to embed

        Returns:
            List of embeddings. one for each text.
        """
        results = []
        for text in texts:
            response = self.embed_query(text)
            results.append(response)
        return results

    def embed_query(self, text: str) -> List[float]:
        """Call out to Vicuna's server embedding endpoint for embedding query text.

        Args:
            text: The text to embed.
        Returns:
            Embedding for the text
        """
        embedding = self._call(text)
        return embedding
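Both classes call a DB-GPT model server over HTTP; a usage sketch, assuming such a server is configured via CFG.MODEL_SERVER and reachable, and that the prompt and generation settings are illustrative:

# Sketch only: requires a running DB-GPT model server; values are illustrative.
from dbgpt.model.llm_out.vicuna_llm import VicunaEmbeddingLLM, VicunaLLM

llm = VicunaLLM()
for partial in llm._call(
    prompt="What is DB-GPT?", temperature=0.7, max_new_tokens=128, stop=["###"]
):
    print(partial, end="\r")

embedder = VicunaEmbeddingLLM()
vectors = embedder.embed_documents(["DB-GPT", "vector store"])
print(len(vectors))  # 2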
95 dbgpt/model/llm_out/vllm_llm.py Normal file
@@ -0,0 +1,95 @@
from typing import Dict
import os
from vllm import AsyncLLMEngine
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams


_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"


async def generate_stream(
    model: AsyncLLMEngine, tokenizer, params: Dict, device: str, context_len: int
):
    """
    Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py
    """
    prompt = params["prompt"]
    request_id = params.pop("request_id") if "request_id" in params else random_uuid()
    temperature = float(params.get("temperature", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop", None)

    stop_token_ids = params.get("stop_token_ids", None) or []
    if tokenizer.eos_token_id is not None:
        stop_token_ids.append(tokenizer.eos_token_id)

    # Handle stop_str
    stop = set()
    if isinstance(stop_str, str) and stop_str != "":
        stop.add(stop_str)
    elif isinstance(stop_str, list) and stop_str != []:
        stop.update(stop_str)

    for tid in stop_token_ids:
        if tid is not None:
            stop.add(tokenizer.decode(tid))

    # make sampling params in vllm
    top_p = max(top_p, 1e-5)
    if temperature <= 1e-5:
        top_p = 1.0
    gen_params = {
        "stop": list(stop),
        "ignore_eos": False,
    }
    prompt_token_ids = None
    if _IS_BENCHMARK:
        gen_params["stop"] = []
        gen_params["ignore_eos"] = True
        prompt_len = context_len - max_new_tokens - 2
        prompt_token_ids = tokenizer([prompt]).input_ids[0]
        prompt_token_ids = prompt_token_ids[-prompt_len:]
    sampling_params = SamplingParams(
        n=1,
        temperature=temperature,
        top_p=top_p,
        use_beam_search=False,
        max_tokens=max_new_tokens,
        **gen_params
    )

    results_generator = model.generate(
        prompt, sampling_params, request_id, prompt_token_ids=prompt_token_ids
    )
    async for request_output in results_generator:
        prompt = request_output.prompt
        if echo:
            text_outputs = [prompt + output.text for output in request_output.outputs]
        else:
            text_outputs = [output.text for output in request_output.outputs]
        text_outputs = " ".join(text_outputs)

        # Note: usage is not supported yet
        prompt_tokens = len(request_output.prompt_token_ids)
        completion_tokens = sum(
            len(output.token_ids) for output in request_output.outputs
        )
        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        finish_reason = (
            request_output.outputs[0].finish_reason
            if len(request_output.outputs) == 1
            else [output.finish_reason for output in request_output.outputs]
        )
        yield {
            "text": text_outputs,
            "error_code": 0,
            "usage": usage,
            "finish_reason": finish_reason,
        }
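A hedged sketch of how this coroutine could be consumed, assuming a vLLM AsyncLLMEngine built from AsyncEngineArgs (the public vLLM API around the time of this commit) and a tokenizer loaded separately; the model id, prompt, and token budget are assumptions:

# Sketch: consuming generate_stream with an AsyncLLMEngine; requires a GPU and
# a vLLM version contemporary with this commit. Values below are illustrative.
import asyncio

from transformers import AutoTokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

from dbgpt.model.llm_out.vllm_llm import generate_stream

MODEL = "facebook/opt-125m"  # small illustrative model


async def main():
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model=MODEL))
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    params = {"prompt": "DB-GPT is", "max_new_tokens": 32, "echo": False}
    async for chunk in generate_stream(
        engine, tokenizer, params, device="cuda", context_len=2048
    ):
        print(chunk["text"], end="\r")


asyncio.run(main())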