mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 04:36:23 +00:00
feat(model): Support llama.cpp server deploy (#2263)
This commit is contained in:
parent
576da34e92
commit
0b2af2e9a2
@ -65,6 +65,14 @@ class APIChatCompletionRequest(BaseModel):
|
||||
presence_penalty: Optional[float] = Field(0.0, description="Presence penalty")
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
"""Usage info entity."""
|
||||
|
||||
prompt_tokens: int = Field(0, description="Prompt tokens")
|
||||
total_tokens: int = Field(0, description="Total tokens")
|
||||
completion_tokens: Optional[int] = Field(0, description="Completion tokens")
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
"""Delta message entity for chat completion response."""
|
||||
|
||||
@ -95,6 +103,7 @@ class ChatCompletionStreamResponse(BaseModel):
|
||||
choices: List[ChatCompletionResponseStreamChoice] = Field(
|
||||
..., description="Chat completion response choices"
|
||||
)
|
||||
usage: UsageInfo = Field(..., description="Usage info")
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
@ -104,14 +113,6 @@ class ChatMessage(BaseModel):
|
||||
content: str = Field(..., description="Content of the message")
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
"""Usage info entity."""
|
||||
|
||||
prompt_tokens: int = Field(0, description="Prompt tokens")
|
||||
total_tokens: int = Field(0, description="Total tokens")
|
||||
completion_tokens: Optional[int] = Field(0, description="Completion tokens")
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
"""Chat completion response choice entity."""
|
||||
|
||||
@ -256,3 +257,157 @@ class ErrorCode(IntEnum):
|
||||
GRADIO_STREAM_UNKNOWN_ERROR = 50004
|
||||
CONTROLLER_NO_WORKER = 50005
|
||||
CONTROLLER_WORKER_TIMEOUT = 50006
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
"""Completion request entity."""
|
||||
|
||||
model: str = Field(..., description="Model name")
|
||||
prompt: Union[str, List[Any]] = Field(
|
||||
...,
|
||||
description="Provide the prompt for this completion as a string or as an "
|
||||
"array of strings or numbers representing tokens",
|
||||
)
|
||||
suffix: Optional[str] = Field(
|
||||
None,
|
||||
description="Suffix to append to the completion. If provided, the model will "
|
||||
"stop generating upon reaching this suffix",
|
||||
)
|
||||
temperature: Optional[float] = Field(
|
||||
0.8,
|
||||
description="Adjust the randomness of the generated text. Default: `0.8`",
|
||||
)
|
||||
n: Optional[int] = Field(
|
||||
1,
|
||||
description="Number of completions to generate. Default: `1`",
|
||||
)
|
||||
max_tokens: Optional[int] = Field(
|
||||
16,
|
||||
description="The maximum number of tokens that can be generated in the "
|
||||
"completion. Default: `16`",
|
||||
)
|
||||
stop: Optional[Union[str, List[str]]] = Field(
|
||||
None,
|
||||
description="Up to 4 sequences where the API will stop generating further "
|
||||
"tokens. The returned text will not contain the stop sequence.",
|
||||
)
|
||||
stream: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to stream back partial completions. Default: `False`",
|
||||
)
|
||||
top_p: Optional[float] = Field(
|
||||
1.0,
|
||||
description="Limit the next token selection to a subset of tokens with a "
|
||||
"cumulative probability above a threshold P. Default: `1.0`",
|
||||
)
|
||||
top_k: Optional[int] = Field(
|
||||
-1,
|
||||
description="Limit the next token selection to the K most probable tokens. "
|
||||
"Default: `-1`",
|
||||
)
|
||||
logprobs: Optional[int] = Field(
|
||||
None,
|
||||
description="Modify the likelihood of specified tokens appearing in the "
|
||||
"completion.",
|
||||
)
|
||||
echo: Optional[bool] = Field(
|
||||
False, description="Echo back the prompt in addition to the completion"
|
||||
)
|
||||
presence_penalty: Optional[float] = Field(
|
||||
0.0,
|
||||
description="Number between -2.0 and 2.0. Positive values penalize new tokens "
|
||||
"based on whether they appear in the text so far, increasing the model's "
|
||||
"likelihood to talk about new topics.",
|
||||
)
|
||||
frequency_penalty: Optional[float] = Field(
|
||||
0.0,
|
||||
description="Number between -2.0 and 2.0. Positive values penalize new tokens "
|
||||
"based on their existing frequency in the text so far, decreasing the model's "
|
||||
"likelihood to repeat the same line verbatim.",
|
||||
)
|
||||
user: Optional[str] = Field(
|
||||
None,
|
||||
description="A unique identifier representing your end-user, which can help "
|
||||
"OpenAI to monitor and detect abuse.",
|
||||
)
|
||||
use_beam_search: Optional[bool] = False
|
||||
best_of: Optional[int] = Field(
|
||||
1,
|
||||
description='Generates best_of completions server-side and returns the "best" '
|
||||
"(the one with the highest log probability per token). Results cannot be "
|
||||
"streamed. When used with n, best_of controls the number of candidate "
|
||||
"completions and n specifies how many to return – best_of must be greater than "
|
||||
"n.",
|
||||
)
|
||||
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
"""Logprobs entity."""
|
||||
|
||||
text_offset: List[int] = Field(default_factory=list, description="Text offset")
|
||||
token_logprobs: List[Optional[float]] = Field(
|
||||
default_factory=list, description="Token logprobs"
|
||||
)
|
||||
tokens: List[str] = Field(default_factory=list, description="Tokens")
|
||||
top_logprobs: List[Optional[Dict[str, float]]] = Field(
|
||||
default_factory=list, description="Top logprobs"
|
||||
)
|
||||
|
||||
|
||||
class CompletionResponseChoice(BaseModel):
|
||||
"""Completion response choice entity."""
|
||||
|
||||
index: int = Field(..., description="Choice index")
|
||||
text: str = Field(..., description="Text")
|
||||
logprobs: Optional[LogProbs] = Field(None, description="Logprobs")
|
||||
finish_reason: Optional[Literal["stop", "length"]] = Field(
|
||||
None, description="The reason the model stopped generating tokens."
|
||||
)
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
"""Completion response entity."""
|
||||
|
||||
id: str = Field(default_factory=lambda: f"cmpl-{str(uuid.uuid1())}")
|
||||
object: str = Field(
|
||||
"text_completion",
|
||||
description="The object type, which is always 'text_completion'",
|
||||
)
|
||||
created: int = Field(
|
||||
default_factory=lambda: int(time.time()), description="Created time"
|
||||
)
|
||||
model: str = Field(..., description="Model name")
|
||||
choices: List[CompletionResponseChoice] = Field(
|
||||
...,
|
||||
description="The list of completion choices the model generated for the input "
|
||||
"prompt.",
|
||||
)
|
||||
usage: UsageInfo = Field(..., description="Usage info")
|
||||
|
||||
|
||||
class CompletionResponseStreamChoice(BaseModel):
|
||||
"""Completion response choice entity."""
|
||||
|
||||
index: int = Field(..., description="Choice index")
|
||||
text: str = Field(..., description="Text")
|
||||
logprobs: Optional[LogProbs] = Field(None, description="Logprobs")
|
||||
finish_reason: Optional[Literal["stop", "length"]] = Field(
|
||||
None, description="The reason the model stopped generating tokens."
|
||||
)
|
||||
|
||||
|
||||
class CompletionStreamResponse(BaseModel):
|
||||
"""Completion stream response entity."""
|
||||
|
||||
id: str = Field(
|
||||
default_factory=lambda: f"cmpl-{str(uuid.uuid1())}", description="Stream ID"
|
||||
)
|
||||
object: str = Field("text_completion", description="Object type")
|
||||
created: int = Field(
|
||||
default_factory=lambda: int(time.time()), description="Created time"
|
||||
)
|
||||
model: str = Field(..., description="Model name")
|
||||
choices: List[CompletionResponseStreamChoice] = Field(
|
||||
..., description="Completion response choices"
|
||||
)
|
||||
usage: UsageInfo = Field(..., description="Usage info")
|
||||
|
@ -145,6 +145,14 @@ class LLMModelAdapter(ABC):
|
||||
"""Whether the loaded model supports asynchronous calls"""
|
||||
return False
|
||||
|
||||
def support_generate_function(self) -> bool:
|
||||
"""Whether support generate function, if it is False, we will use
|
||||
generate_stream function.
|
||||
|
||||
Sometimes, we need to use generate function to get the result of the model.
|
||||
"""
|
||||
return False
|
||||
|
||||
def get_generate_stream_function(self, model, model_path: str):
|
||||
"""Get the generate stream function of the model"""
|
||||
raise NotImplementedError
|
||||
@ -153,6 +161,14 @@ class LLMModelAdapter(ABC):
|
||||
"""Get the asynchronous generate stream function of the model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_generate_function(self, model, model_path: str):
|
||||
"""Get the generate function of the model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_async_generate_function(self, model, model_path: str):
|
||||
"""Get the asynchronous generate function of the model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_default_conv_template(
|
||||
self, model_name: str, model_path: str
|
||||
) -> Optional[ConversationAdapter]:
|
||||
|
256
dbgpt/model/adapter/llama_cpp_adapter.py
Normal file
256
dbgpt/model/adapter/llama_cpp_adapter.py
Normal file
@ -0,0 +1,256 @@
|
||||
"""llama.cpp server adapter.
|
||||
|
||||
See more details:
|
||||
https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md
|
||||
|
||||
**Features:**
|
||||
* LLM inference of F16 and quantized models on GPU and CPU
|
||||
* Parallel decoding with multi-user support
|
||||
* Continuous batching
|
||||
|
||||
The llama.cpp server is pure C++ server, we need to use the llama-cpp-server-py-core
|
||||
to interact with it.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
from dbgpt.configs.model_config import get_device
|
||||
from dbgpt.core import ModelOutput
|
||||
from dbgpt.model.adapter.base import ConversationAdapter, LLMModelAdapter
|
||||
from dbgpt.model.base import ModelType
|
||||
from dbgpt.model.parameter import ModelParameters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
try:
|
||||
from llama_cpp_server_py_core import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionStreamResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
LlamaCppServer,
|
||||
ServerConfig,
|
||||
ServerProcess,
|
||||
)
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"Failed to import llama_cpp_server_py_core, please install it first by `pip install llama-cpp-server-py-core`"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LlamaServerParameters(ServerConfig, ModelParameters):
|
||||
lora_files: Optional[str] = dataclasses.field(
|
||||
default=None, metadata={"help": "Lora files path"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.model_name:
|
||||
self.model_alias = self.model_name
|
||||
|
||||
if self.model_path and not self.model_file:
|
||||
self.model_file = self.model_path
|
||||
|
||||
if self.lora_files and isinstance(self.lora_files, str):
|
||||
self.lora_files = self.lora_files.split(",") # type: ignore
|
||||
elif not self.lora_files:
|
||||
self.lora_files = [] # type: ignore
|
||||
|
||||
if self.model_path:
|
||||
self.model_hf_repo = None
|
||||
self.model_hf_file = None
|
||||
device = self.device or get_device()
|
||||
if device and device == "cuda" and not self.n_gpu_layers:
|
||||
# Set n_gpu_layers to a large number to use all layers
|
||||
logger.info("Set n_gpu_layers to a large number to use all layers")
|
||||
self.n_gpu_layers = 1000000000
|
||||
|
||||
|
||||
class LLamaServerModelAdapter(LLMModelAdapter):
|
||||
def new_adapter(self, **kwargs) -> "LLamaServerModelAdapter":
|
||||
return self.__class__()
|
||||
|
||||
def model_type(self) -> str:
|
||||
return ModelType.LLAMA_CPP_SERVER
|
||||
|
||||
def model_param_class(self, model_type: str = None) -> Type[LlamaServerParameters]:
|
||||
return LlamaServerParameters
|
||||
|
||||
def get_default_conv_template(
|
||||
self, model_name: str, model_path: str
|
||||
) -> Optional[ConversationAdapter]:
|
||||
return None
|
||||
|
||||
def load_from_params(self, params: LlamaServerParameters):
|
||||
server = ServerProcess(params)
|
||||
server.start(300)
|
||||
model_server = LlamaCppServer(server, params)
|
||||
return model_server, model_server
|
||||
|
||||
def support_generate_function(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_generate_stream_function(self, model, model_path: str):
|
||||
return generate_stream
|
||||
|
||||
def get_generate_function(self, model, model_path: str):
|
||||
return generate
|
||||
|
||||
|
||||
def generate_stream(
|
||||
model: LlamaCppServer,
|
||||
tokenizer: LlamaCppServer,
|
||||
params: Dict,
|
||||
device: str,
|
||||
context_len: int,
|
||||
):
|
||||
chat_model = params.get("chat_model", True)
|
||||
if chat_model is None:
|
||||
chat_model = True
|
||||
if chat_model:
|
||||
for out in chat_generate_stream(model, params):
|
||||
yield out
|
||||
else:
|
||||
req = _build_completion_request(params, stream=True)
|
||||
# resp = model.stream_completion(req)
|
||||
text = ""
|
||||
for r in model.stream_completion(req):
|
||||
text += r.content
|
||||
timings = r.timings
|
||||
usage = {
|
||||
"completion_tokens": r.tokens_predicted,
|
||||
"prompt_tokens": r.tokens_evaluated,
|
||||
"total_tokens": r.tokens_predicted + r.tokens_evaluated,
|
||||
}
|
||||
if timings:
|
||||
logger.debug(f"Timings: {timings}")
|
||||
yield ModelOutput(
|
||||
text=text,
|
||||
error_code=0,
|
||||
finish_reason=_parse_finish_reason(r.stop_type),
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
def chat_generate_stream(
|
||||
model: LlamaCppServer,
|
||||
params: Dict,
|
||||
):
|
||||
req = _build_chat_completion_request(params, stream=True)
|
||||
text = ""
|
||||
for r in model.stream_chat_completion(req):
|
||||
if len(r.choices) == 0:
|
||||
continue
|
||||
# Check for empty 'choices' issue in Azure GPT-4o responses
|
||||
if r.choices[0] is not None and r.choices[0].delta is None:
|
||||
continue
|
||||
content = r.choices[0].delta.content
|
||||
finish_reason = _parse_finish_reason(r.choices[0].finish_reason)
|
||||
|
||||
if content is not None:
|
||||
content = r.choices[0].delta.content
|
||||
text += content
|
||||
yield ModelOutput(
|
||||
text=text, error_code=0, finish_reason=finish_reason, usage=r.usage
|
||||
)
|
||||
elif text and content is None:
|
||||
# Last response is empty, return the text
|
||||
yield ModelOutput(
|
||||
text=text, error_code=0, finish_reason=finish_reason, usage=r.usage
|
||||
)
|
||||
|
||||
|
||||
def _build_chat_completion_request(
|
||||
params: Dict, stream: bool = True
|
||||
) -> ChatCompletionRequest:
|
||||
from dbgpt.model.proxy.llms.proxy_model import parse_model_request
|
||||
|
||||
# LLamaCppServer does not need to parse the model
|
||||
model_request = parse_model_request(params, "", stream=stream)
|
||||
return ChatCompletionRequest(
|
||||
messages=model_request.to_common_messages(),
|
||||
temperature=params.get("temperature"),
|
||||
top_p=params.get("top_p"),
|
||||
top_k=params.get("top_k"),
|
||||
max_tokens=params.get("max_new_tokens"),
|
||||
stop=params.get("stop"),
|
||||
stream=stream,
|
||||
presence_penalty=params.get("presence_penalty"),
|
||||
frequency_penalty=params.get("frequency_penalty"),
|
||||
user=params.get("user_name"),
|
||||
)
|
||||
|
||||
|
||||
def _build_completion_request(params: Dict, stream: bool = True) -> CompletionRequest:
|
||||
from dbgpt.model.proxy.llms.proxy_model import parse_model_request
|
||||
|
||||
# LLamaCppServer does not need to parse the model
|
||||
model_request = parse_model_request(params, "", stream=stream)
|
||||
prompt = params.get("prompt")
|
||||
if not prompt and model_request.messages:
|
||||
prompt = model_request.messages[-1].content
|
||||
if not prompt:
|
||||
raise ValueError("Prompt is required for non-chat model")
|
||||
|
||||
return CompletionRequest(
|
||||
prompt=prompt,
|
||||
temperature=params.get("temperature"),
|
||||
top_p=params.get("top_p"),
|
||||
top_k=params.get("top_k"),
|
||||
n_predict=params.get("max_new_tokens"),
|
||||
stop=params.get("stop"),
|
||||
stream=stream,
|
||||
presence_penalty=params.get("presence_penalty"),
|
||||
frequency_penalty=params.get("frequency_penalty"),
|
||||
)
|
||||
|
||||
|
||||
def generate(
|
||||
model: LlamaCppServer,
|
||||
tokenizer: LlamaCppServer,
|
||||
params: Dict,
|
||||
device: str,
|
||||
context_len: int,
|
||||
):
|
||||
chat_model = params.get("chat_model", True)
|
||||
if chat_model is None:
|
||||
chat_model = True
|
||||
|
||||
if chat_model:
|
||||
req = _build_chat_completion_request(params, stream=False)
|
||||
resp = model.chat_completion(req)
|
||||
if not resp.choices or not resp.choices[0].message:
|
||||
raise ValueError("Response can't be empty")
|
||||
content = resp.choices[0].message.content
|
||||
return ModelOutput(
|
||||
text=content,
|
||||
error_code=0,
|
||||
finish_reason=_parse_finish_reason(resp.choices[0].finish_reason),
|
||||
usage=resp.usage,
|
||||
)
|
||||
|
||||
else:
|
||||
req = _build_completion_request(params, stream=False)
|
||||
resp = model.completion(req)
|
||||
content = resp.content
|
||||
usage = {
|
||||
"completion_tokens": resp.tokens_predicted,
|
||||
"prompt_tokens": resp.tokens_evaluated,
|
||||
"total_tokens": resp.tokens_predicted + resp.tokens_evaluated,
|
||||
}
|
||||
return ModelOutput(
|
||||
text=content,
|
||||
error_code=0,
|
||||
finish_reason=_parse_finish_reason(resp.stop_type),
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
def _parse_finish_reason(finish_reason: Optional[str]) -> Optional[str]:
|
||||
if finish_reason == "limit":
|
||||
return "length"
|
||||
elif finish_reason is not None:
|
||||
return "stop"
|
||||
return None
|
@ -130,6 +130,8 @@ class ModelLoader:
|
||||
return llm_adapter.load_from_params(model_params)
|
||||
elif model_type == ModelType.VLLM:
|
||||
return llm_adapter.load_from_params(model_params)
|
||||
elif model_type == ModelType.LLAMA_CPP_SERVER:
|
||||
return llm_adapter.load_from_params(model_params)
|
||||
else:
|
||||
raise Exception(f"Unkown model type {model_type}")
|
||||
|
||||
|
@ -42,6 +42,13 @@ def get_llm_model_adapter(
|
||||
from dbgpt.model.adapter.vllm_adapter import VLLMModelAdapterWrapper
|
||||
|
||||
return VLLMModelAdapterWrapper(conv_factory)
|
||||
if model_type == ModelType.LLAMA_CPP_SERVER:
|
||||
logger.info(
|
||||
"Current model type is llama_cpp_server, return LLamaServerModelAdapter"
|
||||
)
|
||||
from dbgpt.model.adapter.llama_cpp_adapter import LLamaServerModelAdapter
|
||||
|
||||
return LLamaServerModelAdapter()
|
||||
|
||||
# Import NewHFChatModelAdapter for it can be registered
|
||||
from dbgpt.model.adapter.hf_adapter import NewHFChatModelAdapter
|
||||
|
@ -10,10 +10,11 @@ from dbgpt.util.parameter_utils import ParameterDescription
|
||||
|
||||
|
||||
class ModelType:
|
||||
""" "Type of model"""
|
||||
"""Type of model."""
|
||||
|
||||
HF = "huggingface"
|
||||
LLAMA_CPP = "llama.cpp"
|
||||
LLAMA_CPP_SERVER = "llama_cpp_server"
|
||||
PROXY = "proxy"
|
||||
VLLM = "vllm"
|
||||
# TODO, support more model type
|
||||
|
@ -27,6 +27,11 @@ from dbgpt.core.schema.api import (
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatMessage,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
DeltaMessage,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
@ -272,7 +277,13 @@ class APIServer(BaseComponent):
|
||||
worker_manager = self.get_worker_manager()
|
||||
id = f"chatcmpl-{shortuuid.random()}"
|
||||
finish_stream_events = []
|
||||
curr_usage = UsageInfo()
|
||||
last_usage = UsageInfo()
|
||||
for i in range(n):
|
||||
last_usage.prompt_tokens += curr_usage.prompt_tokens
|
||||
last_usage.completion_tokens += curr_usage.completion_tokens
|
||||
last_usage.total_tokens += curr_usage.total_tokens
|
||||
|
||||
# First chunk with role
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
@ -280,7 +291,10 @@ class APIServer(BaseComponent):
|
||||
finish_reason=None,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=id, choices=[choice_data], model=model_name
|
||||
id=id,
|
||||
choices=[choice_data],
|
||||
model=model_name,
|
||||
usage=last_usage,
|
||||
)
|
||||
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {json_data}\n\n"
|
||||
@ -307,13 +321,28 @@ class APIServer(BaseComponent):
|
||||
delta=DeltaMessage(content=delta_text),
|
||||
finish_reason=model_output.finish_reason,
|
||||
)
|
||||
has_usage = False
|
||||
if model_output.usage:
|
||||
curr_usage = UsageInfo.model_validate(model_output.usage)
|
||||
has_usage = True
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=last_usage.prompt_tokens
|
||||
+ curr_usage.prompt_tokens,
|
||||
total_tokens=last_usage.total_tokens + curr_usage.total_tokens,
|
||||
completion_tokens=last_usage.completion_tokens
|
||||
+ curr_usage.completion_tokens,
|
||||
)
|
||||
else:
|
||||
has_usage = False
|
||||
usage = UsageInfo()
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=id, choices=[choice_data], model=model_name
|
||||
id=id, choices=[choice_data], model=model_name, usage=usage
|
||||
)
|
||||
if delta_text is None:
|
||||
if model_output.finish_reason is not None:
|
||||
finish_stream_events.append(chunk)
|
||||
continue
|
||||
if not has_usage:
|
||||
continue
|
||||
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {json_data}\n\n"
|
||||
|
||||
@ -363,6 +392,118 @@ class APIServer(BaseComponent):
|
||||
|
||||
return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
|
||||
|
||||
async def completion_stream_generator(
|
||||
self, request: CompletionRequest, params: Dict
|
||||
):
|
||||
worker_manager = self.get_worker_manager()
|
||||
id = f"cmpl-{shortuuid.random()}"
|
||||
finish_stream_events = []
|
||||
params["span_id"] = root_tracer.get_current_span_id()
|
||||
curr_usage = UsageInfo()
|
||||
last_usage = UsageInfo()
|
||||
for text in request.prompt:
|
||||
for i in range(request.n):
|
||||
params["prompt"] = text
|
||||
previous_text = ""
|
||||
last_usage.prompt_tokens += curr_usage.prompt_tokens
|
||||
last_usage.completion_tokens += curr_usage.completion_tokens
|
||||
last_usage.total_tokens += curr_usage.total_tokens
|
||||
|
||||
async for model_output in worker_manager.generate_stream(params):
|
||||
model_output: ModelOutput = model_output
|
||||
if model_output.error_code != 0:
|
||||
yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
decoded_unicode = model_output.text.replace("\ufffd", "")
|
||||
delta_text = decoded_unicode[len(previous_text) :]
|
||||
previous_text = (
|
||||
decoded_unicode
|
||||
if len(decoded_unicode) > len(previous_text)
|
||||
else previous_text
|
||||
)
|
||||
|
||||
if len(delta_text) == 0:
|
||||
delta_text = None
|
||||
|
||||
choice_data = CompletionResponseStreamChoice(
|
||||
index=i,
|
||||
text=delta_text or "",
|
||||
# TODO: logprobs
|
||||
logprobs=None,
|
||||
finish_reason=model_output.finish_reason,
|
||||
)
|
||||
if model_output.usage:
|
||||
curr_usage = UsageInfo.model_validate(model_output.usage)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=last_usage.prompt_tokens
|
||||
+ curr_usage.prompt_tokens,
|
||||
total_tokens=last_usage.total_tokens
|
||||
+ curr_usage.total_tokens,
|
||||
completion_tokens=last_usage.completion_tokens
|
||||
+ curr_usage.completion_tokens,
|
||||
)
|
||||
else:
|
||||
usage = UsageInfo()
|
||||
chunk = CompletionStreamResponse(
|
||||
id=id,
|
||||
object="text_completion",
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
usage=UsageInfo.model_validate(usage),
|
||||
)
|
||||
if delta_text is None:
|
||||
if model_output.finish_reason is not None:
|
||||
finish_stream_events.append(chunk)
|
||||
continue
|
||||
json_data = model_to_json(
|
||||
chunk, exclude_unset=True, ensure_ascii=False
|
||||
)
|
||||
yield f"data: {json_data}\n\n"
|
||||
last_usage = curr_usage
|
||||
# There is not "content" field in the last delta message, so exclude_none to exclude field "content".
|
||||
for finish_chunk in finish_stream_events:
|
||||
json_data = model_to_json(
|
||||
finish_chunk, exclude_unset=True, ensure_ascii=False
|
||||
)
|
||||
yield f"data: {json_data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def completion_generate(
|
||||
self, request: CompletionRequest, params: Dict[str, Any]
|
||||
):
|
||||
worker_manager: WorkerManager = self.get_worker_manager()
|
||||
choices = []
|
||||
completions = []
|
||||
for text in request.prompt:
|
||||
for i in range(request.n):
|
||||
params["prompt"] = text
|
||||
model_output = asyncio.create_task(worker_manager.generate(params))
|
||||
completions.append(model_output)
|
||||
try:
|
||||
all_tasks = await asyncio.gather(*completions)
|
||||
except Exception as e:
|
||||
return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
|
||||
usage = UsageInfo()
|
||||
for i, model_output in enumerate(all_tasks):
|
||||
model_output: ModelOutput = model_output
|
||||
if model_output.error_code != 0:
|
||||
return create_error_response(model_output.error_code, model_output.text)
|
||||
choices.append(
|
||||
CompletionResponseChoice(
|
||||
index=i,
|
||||
text=model_output.text,
|
||||
finish_reason=model_output.finish_reason,
|
||||
)
|
||||
)
|
||||
if model_output.usage:
|
||||
task_usage = UsageInfo.model_validate(model_output.usage)
|
||||
for usage_key, usage_value in model_to_dict(task_usage).items():
|
||||
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
|
||||
return CompletionResponse(
|
||||
model=request.model, choices=choices, usage=UsageInfo.model_validate(usage)
|
||||
)
|
||||
|
||||
async def embeddings_generate(
|
||||
self,
|
||||
model: str,
|
||||
@ -485,6 +626,57 @@ async def create_chat_completion(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/completions", dependencies=[Depends(check_api_key)])
|
||||
async def create_completion(
|
||||
request: CompletionRequest, api_server: APIServer = Depends(get_api_server)
|
||||
):
|
||||
await api_server.get_model_instances_or_raise(request.model)
|
||||
error_check_ret = check_requests(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
if isinstance(request.prompt, str):
|
||||
request.prompt = [request.prompt]
|
||||
elif not isinstance(request.prompt, list):
|
||||
return create_error_response(
|
||||
ErrorCode.VALIDATION_TYPE_ERROR,
|
||||
"prompt must be a string or a list of strings",
|
||||
)
|
||||
elif isinstance(request.prompt, list) and not isinstance(request.prompt[0], str):
|
||||
return create_error_response(
|
||||
ErrorCode.VALIDATION_TYPE_ERROR,
|
||||
"prompt must be a string or a list of strings",
|
||||
)
|
||||
|
||||
params = {
|
||||
"model": request.model,
|
||||
"prompt": request.prompt,
|
||||
"chat_model": False,
|
||||
"temperature": request.temperature,
|
||||
"max_new_tokens": request.max_tokens,
|
||||
"stop": request.stop,
|
||||
"top_p": request.top_p,
|
||||
"top_k": request.top_k,
|
||||
"echo": request.echo,
|
||||
"presence_penalty": request.presence_penalty,
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"user": request.user,
|
||||
# "use_beam_search": request.use_beam_search,
|
||||
# "beam_size": request.beam_size,
|
||||
}
|
||||
trace_kwargs = {
|
||||
"operation_name": "dbgpt.model.apiserver.create_completion",
|
||||
"metadata": {k: v for k, v in params.items() if v},
|
||||
}
|
||||
if request.stream:
|
||||
generator = api_server.completion_stream_generator(request, params)
|
||||
trace_generator = root_tracer.wrapper_async_stream(generator, **trace_kwargs)
|
||||
return StreamingResponse(trace_generator, media_type="text/event-stream")
|
||||
else:
|
||||
with root_tracer.start_span(**trace_kwargs):
|
||||
params["span_id"] = root_tracer.get_current_span_id()
|
||||
return await api_server.completion_generate(request, params)
|
||||
|
||||
|
||||
@router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
|
||||
async def create_embeddings(
|
||||
request: EmbeddingsRequest, api_server: APIServer = Depends(get_api_server)
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.core.interface.message import ModelMessage
|
||||
from dbgpt.model.base import WorkerApplyType
|
||||
from dbgpt.model.parameter import WorkerType
|
||||
@ -10,10 +10,14 @@ WORKER_MANAGER_SERVICE_NAME = "WorkerManager"
|
||||
|
||||
|
||||
class PromptRequest(BaseModel):
|
||||
messages: List[ModelMessage]
|
||||
model: str
|
||||
messages: List[ModelMessage] = Field(
|
||||
default_factory=list, description="List of ModelMessage objects"
|
||||
)
|
||||
prompt: str = None
|
||||
temperature: float = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
max_new_tokens: int = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop_token_ids: List[int] = []
|
||||
@ -26,6 +30,10 @@ class PromptRequest(BaseModel):
|
||||
"""Message version, default to v2"""
|
||||
context: Dict[str, Any] = None
|
||||
"""Context information for the model"""
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
chat_model: Optional[bool] = True
|
||||
"""Whether to use chat model"""
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
|
||||
|
@ -55,6 +55,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
model_type = self.llm_adapter.model_type()
|
||||
self.param_cls = self.llm_adapter.model_param_class(model_type)
|
||||
self._support_async = self.llm_adapter.support_async()
|
||||
self._support_generate_func = self.llm_adapter.support_generate_function()
|
||||
|
||||
logger.info(
|
||||
f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}"
|
||||
@ -85,7 +86,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
model_path=self.model_path,
|
||||
model_type=model_type,
|
||||
)
|
||||
if not model_params.device:
|
||||
if hasattr(model_params, "device") and not model_params.device:
|
||||
model_params.device = get_device()
|
||||
logger.info(
|
||||
f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
|
||||
@ -129,7 +130,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
|
||||
def stop(self) -> None:
|
||||
if not self.model:
|
||||
logger.warn("Model has been stopped!!")
|
||||
logger.warning("Model has been stopped!!")
|
||||
return
|
||||
del self.model
|
||||
del self.tokenizer
|
||||
@ -190,9 +191,40 @@ class DefaultModelWorker(ModelWorker):
|
||||
def generate(self, params: Dict) -> ModelOutput:
|
||||
"""Generate non stream result"""
|
||||
output = None
|
||||
for out in self.generate_stream(params):
|
||||
output = out
|
||||
return output
|
||||
if self._support_generate_func:
|
||||
(
|
||||
params,
|
||||
model_context,
|
||||
generate_stream_func,
|
||||
model_span,
|
||||
) = self._prepare_generate_stream(
|
||||
params,
|
||||
span_operation_name="DefaultModelWorker_call.generate_func",
|
||||
is_stream=False,
|
||||
)
|
||||
previous_response = ""
|
||||
last_metrics = ModelInferenceMetrics.create_metrics()
|
||||
is_first_generate = True
|
||||
output = generate_stream_func(
|
||||
self.model, self.tokenizer, params, get_device(), self.context_len
|
||||
)
|
||||
(
|
||||
model_output,
|
||||
incremental_output,
|
||||
output_str,
|
||||
current_metrics,
|
||||
) = self._handle_output(
|
||||
output,
|
||||
previous_response,
|
||||
model_context,
|
||||
last_metrics,
|
||||
is_first_generate,
|
||||
)
|
||||
return model_output
|
||||
else:
|
||||
for out in self.generate_stream(params):
|
||||
output = out
|
||||
return output
|
||||
|
||||
def count_token(self, prompt: str) -> int:
|
||||
return _try_to_count_token(prompt, self.tokenizer, self.model)
|
||||
@ -277,12 +309,45 @@ class DefaultModelWorker(ModelWorker):
|
||||
span.end(metadata={"error": output.to_dict()})
|
||||
|
||||
async def async_generate(self, params: Dict) -> ModelOutput:
|
||||
output = None
|
||||
async for out in self.async_generate_stream(params):
|
||||
output = out
|
||||
return output
|
||||
if self._support_generate_func:
|
||||
(
|
||||
params,
|
||||
model_context,
|
||||
generate_stream_func,
|
||||
model_span,
|
||||
) = self._prepare_generate_stream(
|
||||
params,
|
||||
span_operation_name="DefaultModelWorker_call.generate_func",
|
||||
is_stream=False,
|
||||
)
|
||||
previous_response = ""
|
||||
last_metrics = ModelInferenceMetrics.create_metrics()
|
||||
is_first_generate = True
|
||||
output = await generate_stream_func(
|
||||
self.model, self.tokenizer, params, get_device(), self.context_len
|
||||
)
|
||||
(
|
||||
model_output,
|
||||
incremental_output,
|
||||
output_str,
|
||||
current_metrics,
|
||||
) = self._handle_output(
|
||||
output,
|
||||
previous_response,
|
||||
model_context,
|
||||
last_metrics,
|
||||
is_first_generate,
|
||||
)
|
||||
return model_output
|
||||
else:
|
||||
output = None
|
||||
async for out in self.async_generate_stream(params):
|
||||
output = out
|
||||
return output
|
||||
|
||||
def _prepare_generate_stream(self, params: Dict, span_operation_name: str):
|
||||
def _prepare_generate_stream(
|
||||
self, params: Dict, span_operation_name: str, is_stream=True
|
||||
):
|
||||
params, model_context = self.llm_adapter.model_adaptation(
|
||||
params,
|
||||
self.model_name,
|
||||
@ -290,29 +355,48 @@ class DefaultModelWorker(ModelWorker):
|
||||
self.tokenizer,
|
||||
prompt_template=self.ml.prompt_template,
|
||||
)
|
||||
stream_type = ""
|
||||
if self.support_async():
|
||||
generate_stream_func = self.llm_adapter.get_async_generate_stream_function(
|
||||
self.model, self.model_path
|
||||
)
|
||||
stream_type = "async "
|
||||
logger.info(
|
||||
"current generate stream function is asynchronous stream function"
|
||||
)
|
||||
if not is_stream and self.llm_adapter.support_generate_function():
|
||||
func = self.llm_adapter.get_generate_function(
|
||||
self.model, self.model_path
|
||||
)
|
||||
func_type = "async generate"
|
||||
logger.info(
|
||||
"current generate function is asynchronous generate function"
|
||||
)
|
||||
else:
|
||||
func = self.llm_adapter.get_async_generate_stream_function(
|
||||
self.model, self.model_path
|
||||
)
|
||||
func_type = "async generate stream"
|
||||
logger.info(
|
||||
"current generate stream function is asynchronous generate stream function"
|
||||
)
|
||||
else:
|
||||
generate_stream_func = self.llm_adapter.get_generate_stream_function(
|
||||
self.model, self.model_path
|
||||
)
|
||||
if not is_stream and self.llm_adapter.support_generate_function():
|
||||
func = self.llm_adapter.get_generate_function(
|
||||
self.model, self.model_path
|
||||
)
|
||||
func_type = "generate"
|
||||
logger.info(
|
||||
"current generate function is synchronous generate function"
|
||||
)
|
||||
else:
|
||||
func = self.llm_adapter.get_generate_stream_function(
|
||||
self.model, self.model_path
|
||||
)
|
||||
func_type = "generate stream"
|
||||
logger.info(
|
||||
"current generate stream function is synchronous generate stream function"
|
||||
)
|
||||
str_prompt = params.get("prompt")
|
||||
if not str_prompt:
|
||||
str_prompt = params.get("string_prompt")
|
||||
print(
|
||||
f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n"
|
||||
f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{func_type} output:\n"
|
||||
)
|
||||
|
||||
generate_stream_func_str_name = "{}.{}".format(
|
||||
generate_stream_func.__module__, generate_stream_func.__name__
|
||||
)
|
||||
generate_func_str_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
|
||||
span_params = {k: v for k, v in params.items()}
|
||||
if "messages" in span_params:
|
||||
@ -323,7 +407,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
metadata = {
|
||||
"is_async_func": self.support_async(),
|
||||
"llm_adapter": str(self.llm_adapter),
|
||||
"generate_stream_func": generate_stream_func_str_name,
|
||||
"generate_func": generate_func_str_name,
|
||||
}
|
||||
metadata.update(span_params)
|
||||
metadata.update(model_context)
|
||||
@ -331,7 +415,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
|
||||
model_span = root_tracer.start_span(span_operation_name, metadata=metadata)
|
||||
|
||||
return params, model_context, generate_stream_func, model_span
|
||||
return params, model_context, func, model_span
|
||||
|
||||
def _handle_output(
|
||||
self,
|
||||
|
@ -89,6 +89,8 @@ class RemoteModelWorker(ModelWorker):
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
if response.status_code not in [200, 201]:
|
||||
raise Exception(f"Request to {url} failed, error: {response.text}")
|
||||
return ModelOutput(**response.json())
|
||||
|
||||
def count_token(self, prompt: str) -> int:
|
||||
@ -106,6 +108,8 @@ class RemoteModelWorker(ModelWorker):
|
||||
json={"prompt": prompt},
|
||||
timeout=self.timeout,
|
||||
)
|
||||
if response.status_code not in [200, 201]:
|
||||
raise Exception(f"Request to {url} failed, error: {response.text}")
|
||||
return response.json()
|
||||
|
||||
async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
|
||||
@ -123,6 +127,8 @@ class RemoteModelWorker(ModelWorker):
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
if response.status_code not in [200, 201]:
|
||||
raise Exception(f"Request to {url} failed, error: {response.text}")
|
||||
return ModelMetadata.from_dict(response.json())
|
||||
|
||||
def get_model_metadata(self, params: Dict) -> ModelMetadata:
|
||||
@ -141,6 +147,8 @@ class RemoteModelWorker(ModelWorker):
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
if response.status_code not in [200, 201]:
|
||||
raise Exception(f"Request to {url} failed, error: {response.text}")
|
||||
return response.json()
|
||||
|
||||
async def async_embeddings(self, params: Dict) -> List[List[float]]:
|
||||
@ -156,6 +164,8 @@ class RemoteModelWorker(ModelWorker):
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
if response.status_code not in [200, 201]:
|
||||
raise Exception(f"Request to {url} failed, error: {response.text}")
|
||||
return response.json()
|
||||
|
||||
def _get_trace_headers(self):
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""
|
||||
Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict
|
||||
@ -62,11 +63,11 @@ class LlamaCppModel:
|
||||
self.model.__del__()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, model_path, model_params: LlamaCppModelParameters):
|
||||
def from_pretrained(cls, model_path, model_params: LlamaCppModelParameters):
|
||||
Llama = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).Llama
|
||||
LlamaCache = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).LlamaCache
|
||||
|
||||
result = self()
|
||||
result = cls()
|
||||
cache_capacity = 0
|
||||
cache_capacity_str = model_params.cache_capacity
|
||||
if cache_capacity_str is not None:
|
||||
|
40
docs/docs/installation/advanced_usage/Llamacpp_server.md
Normal file
40
docs/docs/installation/advanced_usage/Llamacpp_server.md
Normal file
@ -0,0 +1,40 @@
|
||||
# LLama.cpp Server
|
||||
|
||||
DB-GPT supports native [llama.cpp server](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md),
|
||||
which supports concurrent requests and continuous batching inference.
|
||||
|
||||
|
||||
## Install dependencies
|
||||
|
||||
```bash
|
||||
pip install -e ".[llama_cpp_server]"
|
||||
```
|
||||
If you want to accelerate the inference speed, and you have a GPU, you can install the following dependencies:
|
||||
|
||||
```bash
|
||||
CMAKE_ARGS="-DGGML_CUDA=ON" pip install -e ".[llama_cpp_server]"
|
||||
```
|
||||
|
||||
## Download the model
|
||||
|
||||
Here, we use the `qwen2.5-0.5b-instruct` model as an example. You can download the model from the [Huggingface](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct-GGUF).
|
||||
|
||||
```bash
|
||||
wget https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/qwen2.5-0.5b-instruct-q4_k_m.gguf?download=true -O /tmp/qwen2.5-0.5b-instruct-q4_k_m.gguf
|
||||
````
|
||||
|
||||
## Modify configuration file
|
||||
|
||||
In the `.env` configuration file, modify the inference type of the model to start `llama.cpp` inference.
|
||||
|
||||
```bash
|
||||
LLM_MODEL=qwen2.5-0.5b-instruct
|
||||
LLM_MODEL_PATH=/tmp/qwen2.5-0.5b-instruct-q4_k_m.gguf
|
||||
MODEL_TYPE=llama_cpp_server
|
||||
```
|
||||
|
||||
## Start the DB-GPT server
|
||||
|
||||
```bash
|
||||
python dbgpt/app/dbgpt_server.py
|
||||
```
|
@ -271,6 +271,10 @@ const sidebars = {
|
||||
type: 'doc',
|
||||
id: 'installation/advanced_usage/vLLM_inference',
|
||||
},
|
||||
{
|
||||
type: 'doc',
|
||||
id: 'installation/advanced_usage/Llamacpp_server',
|
||||
},
|
||||
{
|
||||
type: 'doc',
|
||||
id: 'installation/advanced_usage/OpenAI_SDK_call',
|
||||
|
5
setup.py
5
setup.py
@ -556,7 +556,10 @@ def llama_cpp_requires():
|
||||
"""
|
||||
pip install "dbgpt[llama_cpp]"
|
||||
"""
|
||||
setup_spec.extras["llama_cpp"] = ["llama-cpp-python"]
|
||||
setup_spec.extras["llama_cpp_server"] = ["llama-cpp-server-py"]
|
||||
setup_spec.extras["llama_cpp"] = setup_spec.extras["llama_cpp_server"] + [
|
||||
"llama-cpp-python"
|
||||
]
|
||||
llama_cpp_python_cuda_requires()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user