feat(model): Support llama.cpp server deploy (#2263)

Commit 0b2af2e9a2 (parent 576da34e92), mirror of https://github.com/csunny/DB-GPT.git
@@ -65,6 +65,14 @@ class APIChatCompletionRequest(BaseModel):
     presence_penalty: Optional[float] = Field(0.0, description="Presence penalty")
 
 
+class UsageInfo(BaseModel):
+    """Usage info entity."""
+
+    prompt_tokens: int = Field(0, description="Prompt tokens")
+    total_tokens: int = Field(0, description="Total tokens")
+    completion_tokens: Optional[int] = Field(0, description="Completion tokens")
+
+
 class DeltaMessage(BaseModel):
     """Delta message entity for chat completion response."""
 
@@ -95,6 +103,7 @@ class ChatCompletionStreamResponse(BaseModel):
     choices: List[ChatCompletionResponseStreamChoice] = Field(
         ..., description="Chat completion response choices"
     )
+    usage: UsageInfo = Field(..., description="Usage info")
 
 
 class ChatMessage(BaseModel):
@@ -104,14 +113,6 @@ class ChatMessage(BaseModel):
     content: str = Field(..., description="Content of the message")
 
 
-class UsageInfo(BaseModel):
-    """Usage info entity."""
-
-    prompt_tokens: int = Field(0, description="Prompt tokens")
-    total_tokens: int = Field(0, description="Total tokens")
-    completion_tokens: Optional[int] = Field(0, description="Completion tokens")
-
-
 class ChatCompletionResponseChoice(BaseModel):
     """Chat completion response choice entity."""
 
@@ -256,3 +257,157 @@ class ErrorCode(IntEnum):
     GRADIO_STREAM_UNKNOWN_ERROR = 50004
     CONTROLLER_NO_WORKER = 50005
     CONTROLLER_WORKER_TIMEOUT = 50006
+
+
+class CompletionRequest(BaseModel):
+    """Completion request entity."""
+
+    model: str = Field(..., description="Model name")
+    prompt: Union[str, List[Any]] = Field(
+        ...,
+        description="Provide the prompt for this completion as a string or as an "
+        "array of strings or numbers representing tokens",
+    )
+    suffix: Optional[str] = Field(
+        None,
+        description="Suffix to append to the completion. If provided, the model will "
+        "stop generating upon reaching this suffix",
+    )
+    temperature: Optional[float] = Field(
+        0.8,
+        description="Adjust the randomness of the generated text. Default: `0.8`",
+    )
+    n: Optional[int] = Field(
+        1,
+        description="Number of completions to generate. Default: `1`",
+    )
+    max_tokens: Optional[int] = Field(
+        16,
+        description="The maximum number of tokens that can be generated in the "
+        "completion. Default: `16`",
+    )
+    stop: Optional[Union[str, List[str]]] = Field(
+        None,
+        description="Up to 4 sequences where the API will stop generating further "
+        "tokens. The returned text will not contain the stop sequence.",
+    )
+    stream: Optional[bool] = Field(
+        False,
+        description="Whether to stream back partial completions. Default: `False`",
+    )
+    top_p: Optional[float] = Field(
+        1.0,
+        description="Limit the next token selection to a subset of tokens with a "
+        "cumulative probability above a threshold P. Default: `1.0`",
+    )
+    top_k: Optional[int] = Field(
+        -1,
+        description="Limit the next token selection to the K most probable tokens. "
+        "Default: `-1`",
+    )
+    logprobs: Optional[int] = Field(
+        None,
+        description="Modify the likelihood of specified tokens appearing in the "
+        "completion.",
+    )
+    echo: Optional[bool] = Field(
+        False, description="Echo back the prompt in addition to the completion"
+    )
+    presence_penalty: Optional[float] = Field(
+        0.0,
+        description="Number between -2.0 and 2.0. Positive values penalize new tokens "
+        "based on whether they appear in the text so far, increasing the model's "
+        "likelihood to talk about new topics.",
+    )
+    frequency_penalty: Optional[float] = Field(
+        0.0,
+        description="Number between -2.0 and 2.0. Positive values penalize new tokens "
+        "based on their existing frequency in the text so far, decreasing the model's "
+        "likelihood to repeat the same line verbatim.",
+    )
+    user: Optional[str] = Field(
+        None,
+        description="A unique identifier representing your end-user, which can help "
+        "OpenAI to monitor and detect abuse.",
+    )
+    use_beam_search: Optional[bool] = False
+    best_of: Optional[int] = Field(
+        1,
+        description='Generates best_of completions server-side and returns the "best" '
+        "(the one with the highest log probability per token). Results cannot be "
+        "streamed. When used with n, best_of controls the number of candidate "
+        "completions and n specifies how many to return – best_of must be greater than "
+        "n.",
+    )
+
+
+class LogProbs(BaseModel):
+    """Logprobs entity."""
+
+    text_offset: List[int] = Field(default_factory=list, description="Text offset")
+    token_logprobs: List[Optional[float]] = Field(
+        default_factory=list, description="Token logprobs"
+    )
+    tokens: List[str] = Field(default_factory=list, description="Tokens")
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(
+        default_factory=list, description="Top logprobs"
+    )
+
+
+class CompletionResponseChoice(BaseModel):
+    """Completion response choice entity."""
+
+    index: int = Field(..., description="Choice index")
+    text: str = Field(..., description="Text")
+    logprobs: Optional[LogProbs] = Field(None, description="Logprobs")
+    finish_reason: Optional[Literal["stop", "length"]] = Field(
+        None, description="The reason the model stopped generating tokens."
+    )
+
+
+class CompletionResponse(BaseModel):
+    """Completion response entity."""
+
+    id: str = Field(default_factory=lambda: f"cmpl-{str(uuid.uuid1())}")
+    object: str = Field(
+        "text_completion",
+        description="The object type, which is always 'text_completion'",
+    )
+    created: int = Field(
+        default_factory=lambda: int(time.time()), description="Created time"
+    )
+    model: str = Field(..., description="Model name")
+    choices: List[CompletionResponseChoice] = Field(
+        ...,
+        description="The list of completion choices the model generated for the input "
+        "prompt.",
+    )
+    usage: UsageInfo = Field(..., description="Usage info")
+
+
+class CompletionResponseStreamChoice(BaseModel):
+    """Completion response choice entity."""
+
+    index: int = Field(..., description="Choice index")
+    text: str = Field(..., description="Text")
+    logprobs: Optional[LogProbs] = Field(None, description="Logprobs")
+    finish_reason: Optional[Literal["stop", "length"]] = Field(
+        None, description="The reason the model stopped generating tokens."
+    )
+
+
+class CompletionStreamResponse(BaseModel):
+    """Completion stream response entity."""
+
+    id: str = Field(
+        default_factory=lambda: f"cmpl-{str(uuid.uuid1())}", description="Stream ID"
+    )
+    object: str = Field("text_completion", description="Object type")
+    created: int = Field(
+        default_factory=lambda: int(time.time()), description="Created time"
+    )
+    model: str = Field(..., description="Model name")
+    choices: List[CompletionResponseStreamChoice] = Field(
+        ..., description="Completion response choices"
+    )
+    usage: UsageInfo = Field(..., description="Usage info")
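The new completion schema mirrors the OpenAI text-completion API. A minimal sketch of building a request and a usage record by hand, assuming these models live in `dbgpt.core.schema.api` (as the API-server import hunk later in this commit suggests) and that `model_to_json` is exported from DB-GPT's pydantic compatibility layer (path assumed):

```python
# Hedged sketch: import paths and serialization helper are assumptions, not part
# of this commit.
from dbgpt._private.pydantic import model_to_json
from dbgpt.core.schema.api import CompletionRequest, UsageInfo

req = CompletionRequest(
    model="qwen2.5-0.5b-instruct",
    prompt="The capital of France is",
    max_tokens=16,
    stream=False,
)
# JSON body suitable for POST /v1/completions on the DB-GPT API server.
print(model_to_json(req, exclude_unset=True))

usage = UsageInfo(prompt_tokens=8, completion_tokens=2, total_tokens=10)
print(usage)
```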
@@ -145,6 +145,14 @@ class LLMModelAdapter(ABC):
         """Whether the loaded model supports asynchronous calls"""
         return False
 
+    def support_generate_function(self) -> bool:
+        """Whether support generate function, if it is False, we will use
+        generate_stream function.
+
+        Sometimes, we need to use generate function to get the result of the model.
+        """
+        return False
+
     def get_generate_stream_function(self, model, model_path: str):
         """Get the generate stream function of the model"""
         raise NotImplementedError
@@ -153,6 +161,14 @@ class LLMModelAdapter(ABC):
         """Get the asynchronous generate stream function of the model"""
         raise NotImplementedError
 
+    def get_generate_function(self, model, model_path: str):
+        """Get the generate function of the model"""
+        raise NotImplementedError
+
+    def get_async_generate_function(self, model, model_path: str):
+        """Get the asynchronous generate function of the model"""
+        raise NotImplementedError
+
     def get_default_conv_template(
         self, model_name: str, model_path: str
     ) -> Optional[ConversationAdapter]:
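An adapter that wants the non-stream path opts in by returning True from `support_generate_function` and returning a one-shot generate callable. A hedged sketch (names such as `MyAdapter`, `my_generate`, and `my_generate_stream` are illustrative only, not part of this commit):

```python
# Hedged sketch of a custom adapter using the new hooks; the import path matches
# the one used by llama_cpp_adapter.py in this commit.
from dbgpt.model.adapter.base import LLMModelAdapter


def my_generate_stream(model, tokenizer, params, device, context_len):
    yield ...  # stream partial ModelOutput objects


def my_generate(model, tokenizer, params, device, context_len):
    return ...  # return a single ModelOutput


class MyAdapter(LLMModelAdapter):
    def support_generate_function(self) -> bool:
        # Tell DefaultModelWorker that a one-shot generate function exists,
        # so non-stream requests skip the generate_stream fallback.
        return True

    def get_generate_stream_function(self, model, model_path: str):
        return my_generate_stream

    def get_generate_function(self, model, model_path: str):
        return my_generate
```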
dbgpt/model/adapter/llama_cpp_adapter.py (new file, 256 lines)

"""llama.cpp server adapter.

See more details:
https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md

**Features:**
* LLM inference of F16 and quantized models on GPU and CPU
* Parallel decoding with multi-user support
* Continuous batching

The llama.cpp server is pure C++ server, we need to use the llama-cpp-server-py-core
to interact with it.
"""

import dataclasses
import logging
from typing import Dict, Optional, Type

from dbgpt.configs.model_config import get_device
from dbgpt.core import ModelOutput
from dbgpt.model.adapter.base import ConversationAdapter, LLMModelAdapter
from dbgpt.model.base import ModelType
from dbgpt.model.parameter import ModelParameters

logger = logging.getLogger(__name__)

try:
    from llama_cpp_server_py_core import (
        ChatCompletionRequest,
        ChatCompletionStreamResponse,
        CompletionRequest,
        CompletionResponse,
        LlamaCppServer,
        ServerConfig,
        ServerProcess,
    )
except ImportError:
    logger.error(
        "Failed to import llama_cpp_server_py_core, please install it first by `pip install llama-cpp-server-py-core`"
    )
    raise


@dataclasses.dataclass
class LlamaServerParameters(ServerConfig, ModelParameters):
    lora_files: Optional[str] = dataclasses.field(
        default=None, metadata={"help": "Lora files path"}
    )

    def __post_init__(self):
        if self.model_name:
            self.model_alias = self.model_name

        if self.model_path and not self.model_file:
            self.model_file = self.model_path

        if self.lora_files and isinstance(self.lora_files, str):
            self.lora_files = self.lora_files.split(",")  # type: ignore
        elif not self.lora_files:
            self.lora_files = []  # type: ignore

        if self.model_path:
            self.model_hf_repo = None
            self.model_hf_file = None
        device = self.device or get_device()
        if device and device == "cuda" and not self.n_gpu_layers:
            # Set n_gpu_layers to a large number to use all layers
            logger.info("Set n_gpu_layers to a large number to use all layers")
            self.n_gpu_layers = 1000000000


class LLamaServerModelAdapter(LLMModelAdapter):
    def new_adapter(self, **kwargs) -> "LLamaServerModelAdapter":
        return self.__class__()

    def model_type(self) -> str:
        return ModelType.LLAMA_CPP_SERVER

    def model_param_class(self, model_type: str = None) -> Type[LlamaServerParameters]:
        return LlamaServerParameters

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> Optional[ConversationAdapter]:
        return None

    def load_from_params(self, params: LlamaServerParameters):
        server = ServerProcess(params)
        server.start(300)
        model_server = LlamaCppServer(server, params)
        return model_server, model_server

    def support_generate_function(self) -> bool:
        return True

    def get_generate_stream_function(self, model, model_path: str):
        return generate_stream

    def get_generate_function(self, model, model_path: str):
        return generate


def generate_stream(
    model: LlamaCppServer,
    tokenizer: LlamaCppServer,
    params: Dict,
    device: str,
    context_len: int,
):
    chat_model = params.get("chat_model", True)
    if chat_model is None:
        chat_model = True
    if chat_model:
        for out in chat_generate_stream(model, params):
            yield out
    else:
        req = _build_completion_request(params, stream=True)
        # resp = model.stream_completion(req)
        text = ""
        for r in model.stream_completion(req):
            text += r.content
            timings = r.timings
            usage = {
                "completion_tokens": r.tokens_predicted,
                "prompt_tokens": r.tokens_evaluated,
                "total_tokens": r.tokens_predicted + r.tokens_evaluated,
            }
            if timings:
                logger.debug(f"Timings: {timings}")
            yield ModelOutput(
                text=text,
                error_code=0,
                finish_reason=_parse_finish_reason(r.stop_type),
                usage=usage,
            )


def chat_generate_stream(
    model: LlamaCppServer,
    params: Dict,
):
    req = _build_chat_completion_request(params, stream=True)
    text = ""
    for r in model.stream_chat_completion(req):
        if len(r.choices) == 0:
            continue
        # Check for empty 'choices' issue in Azure GPT-4o responses
        if r.choices[0] is not None and r.choices[0].delta is None:
            continue
        content = r.choices[0].delta.content
        finish_reason = _parse_finish_reason(r.choices[0].finish_reason)

        if content is not None:
            content = r.choices[0].delta.content
            text += content
            yield ModelOutput(
                text=text, error_code=0, finish_reason=finish_reason, usage=r.usage
            )
        elif text and content is None:
            # Last response is empty, return the text
            yield ModelOutput(
                text=text, error_code=0, finish_reason=finish_reason, usage=r.usage
            )


def _build_chat_completion_request(
    params: Dict, stream: bool = True
) -> ChatCompletionRequest:
    from dbgpt.model.proxy.llms.proxy_model import parse_model_request

    # LLamaCppServer does not need to parse the model
    model_request = parse_model_request(params, "", stream=stream)
    return ChatCompletionRequest(
        messages=model_request.to_common_messages(),
        temperature=params.get("temperature"),
        top_p=params.get("top_p"),
        top_k=params.get("top_k"),
        max_tokens=params.get("max_new_tokens"),
        stop=params.get("stop"),
        stream=stream,
        presence_penalty=params.get("presence_penalty"),
        frequency_penalty=params.get("frequency_penalty"),
        user=params.get("user_name"),
    )


def _build_completion_request(params: Dict, stream: bool = True) -> CompletionRequest:
    from dbgpt.model.proxy.llms.proxy_model import parse_model_request

    # LLamaCppServer does not need to parse the model
    model_request = parse_model_request(params, "", stream=stream)
    prompt = params.get("prompt")
    if not prompt and model_request.messages:
        prompt = model_request.messages[-1].content
    if not prompt:
        raise ValueError("Prompt is required for non-chat model")

    return CompletionRequest(
        prompt=prompt,
        temperature=params.get("temperature"),
        top_p=params.get("top_p"),
        top_k=params.get("top_k"),
        n_predict=params.get("max_new_tokens"),
        stop=params.get("stop"),
        stream=stream,
        presence_penalty=params.get("presence_penalty"),
        frequency_penalty=params.get("frequency_penalty"),
    )


def generate(
    model: LlamaCppServer,
    tokenizer: LlamaCppServer,
    params: Dict,
    device: str,
    context_len: int,
):
    chat_model = params.get("chat_model", True)
    if chat_model is None:
        chat_model = True

    if chat_model:
        req = _build_chat_completion_request(params, stream=False)
        resp = model.chat_completion(req)
        if not resp.choices or not resp.choices[0].message:
            raise ValueError("Response can't be empty")
        content = resp.choices[0].message.content
        return ModelOutput(
            text=content,
            error_code=0,
            finish_reason=_parse_finish_reason(resp.choices[0].finish_reason),
            usage=resp.usage,
        )

    else:
        req = _build_completion_request(params, stream=False)
        resp = model.completion(req)
        content = resp.content
        usage = {
            "completion_tokens": resp.tokens_predicted,
            "prompt_tokens": resp.tokens_evaluated,
            "total_tokens": resp.tokens_predicted + resp.tokens_evaluated,
        }
        return ModelOutput(
            text=content,
            error_code=0,
            finish_reason=_parse_finish_reason(resp.stop_type),
            usage=usage,
        )


def _parse_finish_reason(finish_reason: Optional[str]) -> Optional[str]:
    if finish_reason == "limit":
        return "length"
    elif finish_reason is not None:
        return "stop"
    return None
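A hedged sketch of how the worker stack is expected to wire this adapter together. `LlamaServerParameters` inherits its fields from `ServerConfig` and `ModelParameters`, so the exact constructor arguments (and the params accepted by `parse_model_request`) may differ; treat this as illustrative only:

```python
# Hedged sketch, not part of the commit. Field names beyond those referenced in
# llama_cpp_adapter.py are assumptions.
from dbgpt.model.adapter.llama_cpp_adapter import (
    LLamaServerModelAdapter,
    LlamaServerParameters,
)

adapter = LLamaServerModelAdapter()
param_cls = adapter.model_param_class()  # -> LlamaServerParameters
params = param_cls(
    model_name="qwen2.5-0.5b-instruct",  # assumed field from ModelParameters
    model_path="/tmp/qwen2.5-0.5b-instruct-q4_k_m.gguf",
)

# Starts the llama.cpp server process and wraps it with LlamaCppServer; the worker
# stores the same server handle in both its "model" and "tokenizer" slots.
model, tokenizer = adapter.load_from_params(params)

# Non-stream path: DefaultModelWorker asks the adapter for the generate function.
generate_fn = adapter.get_generate_function(model, model_path=params.model_path)
out = generate_fn(
    model,
    tokenizer,
    {"chat_model": False, "prompt": "Hello", "max_new_tokens": 16},
    device="cpu",
    context_len=4096,
)
print(out.text, out.usage)
```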
@@ -130,6 +130,8 @@ class ModelLoader:
             return llm_adapter.load_from_params(model_params)
         elif model_type == ModelType.VLLM:
             return llm_adapter.load_from_params(model_params)
+        elif model_type == ModelType.LLAMA_CPP_SERVER:
+            return llm_adapter.load_from_params(model_params)
         else:
             raise Exception(f"Unkown model type {model_type}")
@@ -42,6 +42,13 @@ def get_llm_model_adapter(
         from dbgpt.model.adapter.vllm_adapter import VLLMModelAdapterWrapper
 
         return VLLMModelAdapterWrapper(conv_factory)
+    if model_type == ModelType.LLAMA_CPP_SERVER:
+        logger.info(
+            "Current model type is llama_cpp_server, return LLamaServerModelAdapter"
+        )
+        from dbgpt.model.adapter.llama_cpp_adapter import LLamaServerModelAdapter
+
+        return LLamaServerModelAdapter()
 
     # Import NewHFChatModelAdapter for it can be registered
     from dbgpt.model.adapter.hf_adapter import NewHFChatModelAdapter
@@ -10,10 +10,11 @@ from dbgpt.util.parameter_utils import ParameterDescription
 
 
 class ModelType:
-    """ "Type of model"""
+    """Type of model."""
 
     HF = "huggingface"
     LLAMA_CPP = "llama.cpp"
+    LLAMA_CPP_SERVER = "llama_cpp_server"
     PROXY = "proxy"
     VLLM = "vllm"
     # TODO, support more model type
@@ -27,6 +27,11 @@ from dbgpt.core.schema.api import (
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
     ChatMessage,
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseChoice,
+    CompletionResponseStreamChoice,
+    CompletionStreamResponse,
     DeltaMessage,
     EmbeddingsRequest,
     EmbeddingsResponse,
@@ -272,7 +277,13 @@ class APIServer(BaseComponent):
         worker_manager = self.get_worker_manager()
         id = f"chatcmpl-{shortuuid.random()}"
         finish_stream_events = []
+        curr_usage = UsageInfo()
+        last_usage = UsageInfo()
         for i in range(n):
+            last_usage.prompt_tokens += curr_usage.prompt_tokens
+            last_usage.completion_tokens += curr_usage.completion_tokens
+            last_usage.total_tokens += curr_usage.total_tokens
+
             # First chunk with role
             choice_data = ChatCompletionResponseStreamChoice(
                 index=i,
@@ -280,7 +291,10 @@ class APIServer(BaseComponent):
                 finish_reason=None,
             )
             chunk = ChatCompletionStreamResponse(
-                id=id, choices=[choice_data], model=model_name
+                id=id,
+                choices=[choice_data],
+                model=model_name,
+                usage=last_usage,
             )
             json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
             yield f"data: {json_data}\n\n"
@@ -307,12 +321,27 @@ class APIServer(BaseComponent):
                     delta=DeltaMessage(content=delta_text),
                     finish_reason=model_output.finish_reason,
                 )
+                has_usage = False
+                if model_output.usage:
+                    curr_usage = UsageInfo.model_validate(model_output.usage)
+                    has_usage = True
+                    usage = UsageInfo(
+                        prompt_tokens=last_usage.prompt_tokens
+                        + curr_usage.prompt_tokens,
+                        total_tokens=last_usage.total_tokens + curr_usage.total_tokens,
+                        completion_tokens=last_usage.completion_tokens
+                        + curr_usage.completion_tokens,
+                    )
+                else:
+                    has_usage = False
+                    usage = UsageInfo()
                 chunk = ChatCompletionStreamResponse(
-                    id=id, choices=[choice_data], model=model_name
+                    id=id, choices=[choice_data], model=model_name, usage=usage
                 )
                 if delta_text is None:
                     if model_output.finish_reason is not None:
                         finish_stream_events.append(chunk)
-                    continue
+                    if not has_usage:
+                        continue
                 json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
                 yield f"data: {json_data}\n\n"
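With `usage` now attached to every `ChatCompletionStreamResponse`, each streamed chunk carries cumulative token accounting across the `n` generations. A rough illustration of the shape of one SSE `data:` payload after this change (hedged; the id, model name, and token counts are invented):

```python
# Hedged illustration only: approximate structure of one streamed chunk.
chunk = {
    "id": "chatcmpl-3X6b9c",
    "model": "qwen2.5-0.5b-instruct",
    "choices": [
        {"index": 0, "delta": {"content": "Hello"}, "finish_reason": None},
    ],
    "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15},
}
```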
@@ -363,6 +392,118 @@ class APIServer(BaseComponent):
 
         return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
 
+    async def completion_stream_generator(
+        self, request: CompletionRequest, params: Dict
+    ):
+        worker_manager = self.get_worker_manager()
+        id = f"cmpl-{shortuuid.random()}"
+        finish_stream_events = []
+        params["span_id"] = root_tracer.get_current_span_id()
+        curr_usage = UsageInfo()
+        last_usage = UsageInfo()
+        for text in request.prompt:
+            for i in range(request.n):
+                params["prompt"] = text
+                previous_text = ""
+                last_usage.prompt_tokens += curr_usage.prompt_tokens
+                last_usage.completion_tokens += curr_usage.completion_tokens
+                last_usage.total_tokens += curr_usage.total_tokens
+
+                async for model_output in worker_manager.generate_stream(params):
+                    model_output: ModelOutput = model_output
+                    if model_output.error_code != 0:
+                        yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
+                        yield "data: [DONE]\n\n"
+                        return
+                    decoded_unicode = model_output.text.replace("\ufffd", "")
+                    delta_text = decoded_unicode[len(previous_text) :]
+                    previous_text = (
+                        decoded_unicode
+                        if len(decoded_unicode) > len(previous_text)
+                        else previous_text
+                    )
+
+                    if len(delta_text) == 0:
+                        delta_text = None
+
+                    choice_data = CompletionResponseStreamChoice(
+                        index=i,
+                        text=delta_text or "",
+                        # TODO: logprobs
+                        logprobs=None,
+                        finish_reason=model_output.finish_reason,
+                    )
+                    if model_output.usage:
+                        curr_usage = UsageInfo.model_validate(model_output.usage)
+                        usage = UsageInfo(
+                            prompt_tokens=last_usage.prompt_tokens
+                            + curr_usage.prompt_tokens,
+                            total_tokens=last_usage.total_tokens
+                            + curr_usage.total_tokens,
+                            completion_tokens=last_usage.completion_tokens
+                            + curr_usage.completion_tokens,
+                        )
+                    else:
+                        usage = UsageInfo()
+                    chunk = CompletionStreamResponse(
+                        id=id,
+                        object="text_completion",
+                        choices=[choice_data],
+                        model=request.model,
+                        usage=UsageInfo.model_validate(usage),
+                    )
+                    if delta_text is None:
+                        if model_output.finish_reason is not None:
+                            finish_stream_events.append(chunk)
+                        continue
+                    json_data = model_to_json(
+                        chunk, exclude_unset=True, ensure_ascii=False
+                    )
+                    yield f"data: {json_data}\n\n"
+                last_usage = curr_usage
+        # There is not "content" field in the last delta message, so exclude_none to exclude field "content".
+        for finish_chunk in finish_stream_events:
+            json_data = model_to_json(
+                finish_chunk, exclude_unset=True, ensure_ascii=False
+            )
+            yield f"data: {json_data}\n\n"
+        yield "data: [DONE]\n\n"
+
+    async def completion_generate(
+        self, request: CompletionRequest, params: Dict[str, Any]
+    ):
+        worker_manager: WorkerManager = self.get_worker_manager()
+        choices = []
+        completions = []
+        for text in request.prompt:
+            for i in range(request.n):
+                params["prompt"] = text
+                model_output = asyncio.create_task(worker_manager.generate(params))
+                completions.append(model_output)
+        try:
+            all_tasks = await asyncio.gather(*completions)
+        except Exception as e:
+            return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
+        usage = UsageInfo()
+        for i, model_output in enumerate(all_tasks):
+            model_output: ModelOutput = model_output
+            if model_output.error_code != 0:
+                return create_error_response(model_output.error_code, model_output.text)
+            choices.append(
+                CompletionResponseChoice(
+                    index=i,
+                    text=model_output.text,
+                    finish_reason=model_output.finish_reason,
+                )
+            )
+            if model_output.usage:
+                task_usage = UsageInfo.model_validate(model_output.usage)
+                for usage_key, usage_value in model_to_dict(task_usage).items():
+                    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+        return CompletionResponse(
+            model=request.model, choices=choices, usage=UsageInfo.model_validate(usage)
+        )
+
     async def embeddings_generate(
         self,
         model: str,
@@ -485,6 +626,57 @@ async def create_chat_completion(
     )
 
 
+@router.post("/v1/completions", dependencies=[Depends(check_api_key)])
+async def create_completion(
+    request: CompletionRequest, api_server: APIServer = Depends(get_api_server)
+):
+    await api_server.get_model_instances_or_raise(request.model)
+    error_check_ret = check_requests(request)
+    if error_check_ret is not None:
+        return error_check_ret
+    if isinstance(request.prompt, str):
+        request.prompt = [request.prompt]
+    elif not isinstance(request.prompt, list):
+        return create_error_response(
+            ErrorCode.VALIDATION_TYPE_ERROR,
+            "prompt must be a string or a list of strings",
+        )
+    elif isinstance(request.prompt, list) and not isinstance(request.prompt[0], str):
+        return create_error_response(
+            ErrorCode.VALIDATION_TYPE_ERROR,
+            "prompt must be a string or a list of strings",
+        )
+
+    params = {
+        "model": request.model,
+        "prompt": request.prompt,
+        "chat_model": False,
+        "temperature": request.temperature,
+        "max_new_tokens": request.max_tokens,
+        "stop": request.stop,
+        "top_p": request.top_p,
+        "top_k": request.top_k,
+        "echo": request.echo,
+        "presence_penalty": request.presence_penalty,
+        "frequency_penalty": request.frequency_penalty,
+        "user": request.user,
+        # "use_beam_search": request.use_beam_search,
+        # "beam_size": request.beam_size,
+    }
+    trace_kwargs = {
+        "operation_name": "dbgpt.model.apiserver.create_completion",
+        "metadata": {k: v for k, v in params.items() if v},
+    }
+    if request.stream:
+        generator = api_server.completion_stream_generator(request, params)
+        trace_generator = root_tracer.wrapper_async_stream(generator, **trace_kwargs)
+        return StreamingResponse(trace_generator, media_type="text/event-stream")
+    else:
+        with root_tracer.start_span(**trace_kwargs):
+            params["span_id"] = root_tracer.get_current_span_id()
+            return await api_server.completion_generate(request, params)
+
+
 @router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
 async def create_embeddings(
     request: EmbeddingsRequest, api_server: APIServer = Depends(get_api_server)
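The new route is OpenAI-compatible, so it can be exercised with the OpenAI Python SDK. A hedged example (not part of the commit; the base URL, port, and API key are assumptions, adjust them to your running DB-GPT API server):

```python
# Hedged example: querying the new /v1/completions route with the OpenAI SDK.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8100/api/v1", api_key="EMPTY")

resp = client.completions.create(
    model="qwen2.5-0.5b-instruct",
    prompt="The capital of France is",
    max_tokens=8,
)
print(resp.choices[0].text)
print(resp.usage)
```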
@@ -1,6 +1,6 @@
 from typing import Any, Dict, List, Optional, Union
 
-from dbgpt._private.pydantic import BaseModel
+from dbgpt._private.pydantic import BaseModel, Field
 from dbgpt.core.interface.message import ModelMessage
 from dbgpt.model.base import WorkerApplyType
 from dbgpt.model.parameter import WorkerType
@@ -10,10 +10,14 @@ WORKER_MANAGER_SERVICE_NAME = "WorkerManager"
 
 
 class PromptRequest(BaseModel):
-    messages: List[ModelMessage]
     model: str
+    messages: List[ModelMessage] = Field(
+        default_factory=list, description="List of ModelMessage objects"
+    )
     prompt: str = None
     temperature: float = None
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
     max_new_tokens: int = None
     stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: List[int] = []
@@ -26,6 +30,10 @@ class PromptRequest(BaseModel):
     """Message version, default to v2"""
     context: Dict[str, Any] = None
     """Context information for the model"""
+    presence_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    chat_model: Optional[bool] = True
+    """Whether to use chat model"""
 
 
 class EmbeddingsRequest(BaseModel):
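For reference, a worker-level request exercising the new fields might look like the following hedged example (not part of the commit; it assumes `PromptRequest` is imported from this module):

```python
# Hedged example: a non-chat prompt request that routes to the llama.cpp
# /completion endpoint via the new `chat_model` flag.
req = PromptRequest(
    model="qwen2.5-0.5b-instruct",
    prompt="Once upon a time",
    max_new_tokens=32,
    top_p=0.9,
    frequency_penalty=0.1,
    chat_model=False,
)
```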
@@ -55,6 +55,7 @@ class DefaultModelWorker(ModelWorker):
         model_type = self.llm_adapter.model_type()
         self.param_cls = self.llm_adapter.model_param_class(model_type)
         self._support_async = self.llm_adapter.support_async()
+        self._support_generate_func = self.llm_adapter.support_generate_function()
 
         logger.info(
             f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}"
@@ -85,7 +86,7 @@ class DefaultModelWorker(ModelWorker):
             model_path=self.model_path,
             model_type=model_type,
         )
-        if not model_params.device:
+        if hasattr(model_params, "device") and not model_params.device:
             model_params.device = get_device()
             logger.info(
                 f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
@@ -129,7 +130,7 @@ class DefaultModelWorker(ModelWorker):
 
     def stop(self) -> None:
         if not self.model:
-            logger.warn("Model has been stopped!!")
+            logger.warning("Model has been stopped!!")
             return
         del self.model
         del self.tokenizer
@@ -190,6 +191,37 @@ class DefaultModelWorker(ModelWorker):
     def generate(self, params: Dict) -> ModelOutput:
         """Generate non stream result"""
         output = None
+        if self._support_generate_func:
+            (
+                params,
+                model_context,
+                generate_stream_func,
+                model_span,
+            ) = self._prepare_generate_stream(
+                params,
+                span_operation_name="DefaultModelWorker_call.generate_func",
+                is_stream=False,
+            )
+            previous_response = ""
+            last_metrics = ModelInferenceMetrics.create_metrics()
+            is_first_generate = True
+            output = generate_stream_func(
+                self.model, self.tokenizer, params, get_device(), self.context_len
+            )
+            (
+                model_output,
+                incremental_output,
+                output_str,
+                current_metrics,
+            ) = self._handle_output(
+                output,
+                previous_response,
+                model_context,
+                last_metrics,
+                is_first_generate,
+            )
+            return model_output
+        else:
             for out in self.generate_stream(params):
                 output = out
             return output
@@ -277,12 +309,45 @@ class DefaultModelWorker(ModelWorker):
             span.end(metadata={"error": output.to_dict()})
 
     async def async_generate(self, params: Dict) -> ModelOutput:
+        if self._support_generate_func:
+            (
+                params,
+                model_context,
+                generate_stream_func,
+                model_span,
+            ) = self._prepare_generate_stream(
+                params,
+                span_operation_name="DefaultModelWorker_call.generate_func",
+                is_stream=False,
+            )
+            previous_response = ""
+            last_metrics = ModelInferenceMetrics.create_metrics()
+            is_first_generate = True
+            output = await generate_stream_func(
+                self.model, self.tokenizer, params, get_device(), self.context_len
+            )
+            (
+                model_output,
+                incremental_output,
+                output_str,
+                current_metrics,
+            ) = self._handle_output(
+                output,
+                previous_response,
+                model_context,
+                last_metrics,
+                is_first_generate,
+            )
+            return model_output
+        else:
             output = None
             async for out in self.async_generate_stream(params):
                 output = out
             return output
 
-    def _prepare_generate_stream(self, params: Dict, span_operation_name: str):
+    def _prepare_generate_stream(
+        self, params: Dict, span_operation_name: str, is_stream=True
+    ):
         params, model_context = self.llm_adapter.model_adaptation(
             params,
             self.model_name,
@@ -290,29 +355,48 @@ class DefaultModelWorker(ModelWorker):
             self.tokenizer,
             prompt_template=self.ml.prompt_template,
         )
-        stream_type = ""
         if self.support_async():
-            generate_stream_func = self.llm_adapter.get_async_generate_stream_function(
-                self.model, self.model_path
-            )
-            stream_type = "async "
-            logger.info(
-                "current generate stream function is asynchronous stream function"
-            )
+            if not is_stream and self.llm_adapter.support_generate_function():
+                func = self.llm_adapter.get_generate_function(
+                    self.model, self.model_path
+                )
+                func_type = "async generate"
+                logger.info(
+                    "current generate function is asynchronous generate function"
+                )
+            else:
+                func = self.llm_adapter.get_async_generate_stream_function(
+                    self.model, self.model_path
+                )
+                func_type = "async generate stream"
+                logger.info(
+                    "current generate stream function is asynchronous generate stream function"
+                )
         else:
-            generate_stream_func = self.llm_adapter.get_generate_stream_function(
-                self.model, self.model_path
-            )
+            if not is_stream and self.llm_adapter.support_generate_function():
+                func = self.llm_adapter.get_generate_function(
+                    self.model, self.model_path
+                )
+                func_type = "generate"
+                logger.info(
+                    "current generate function is synchronous generate function"
+                )
+            else:
+                func = self.llm_adapter.get_generate_stream_function(
+                    self.model, self.model_path
+                )
+                func_type = "generate stream"
+                logger.info(
+                    "current generate stream function is synchronous generate stream function"
+                )
         str_prompt = params.get("prompt")
         if not str_prompt:
             str_prompt = params.get("string_prompt")
         print(
-            f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n"
+            f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{func_type} output:\n"
         )
 
-        generate_stream_func_str_name = "{}.{}".format(
-            generate_stream_func.__module__, generate_stream_func.__name__
-        )
+        generate_func_str_name = "{}.{}".format(func.__module__, func.__name__)
 
         span_params = {k: v for k, v in params.items()}
         if "messages" in span_params:
@@ -323,7 +407,7 @@ class DefaultModelWorker(ModelWorker):
         metadata = {
             "is_async_func": self.support_async(),
             "llm_adapter": str(self.llm_adapter),
-            "generate_stream_func": generate_stream_func_str_name,
+            "generate_func": generate_func_str_name,
         }
         metadata.update(span_params)
         metadata.update(model_context)
@@ -331,7 +415,7 @@ class DefaultModelWorker(ModelWorker):
 
         model_span = root_tracer.start_span(span_operation_name, metadata=metadata)
 
-        return params, model_context, generate_stream_func, model_span
+        return params, model_context, func, model_span
 
     def _handle_output(
         self,
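The dispatch in `_prepare_generate_stream` now picks one of four adapter hooks depending on whether the worker is async, whether the request is streaming, and whether the adapter advertises a one-shot generate function. A small hedged summary of that decision (illustrative only, mirroring the branches in the hunk above):

```python
# Hedged summary of the new dispatch logic; not part of the commit.
def pick_adapter_hook(support_async: bool, is_stream: bool, support_generate: bool) -> str:
    if support_async:
        if not is_stream and support_generate:
            return "get_generate_function (async generate)"
        return "get_async_generate_stream_function (async generate stream)"
    if not is_stream and support_generate:
        return "get_generate_function (generate)"
    return "get_generate_stream_function (generate stream)"


assert pick_adapter_hook(True, False, True) == "get_generate_function (async generate)"
assert pick_adapter_hook(False, True, True) == "get_generate_stream_function (generate stream)"
```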
@@ -89,6 +89,8 @@ class RemoteModelWorker(ModelWorker):
                 json=params,
                 timeout=self.timeout,
             )
+            if response.status_code not in [200, 201]:
+                raise Exception(f"Request to {url} failed, error: {response.text}")
             return ModelOutput(**response.json())
 
     def count_token(self, prompt: str) -> int:
@@ -106,6 +108,8 @@ class RemoteModelWorker(ModelWorker):
                 json={"prompt": prompt},
                 timeout=self.timeout,
             )
+            if response.status_code not in [200, 201]:
+                raise Exception(f"Request to {url} failed, error: {response.text}")
             return response.json()
 
     async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
@@ -123,6 +127,8 @@ class RemoteModelWorker(ModelWorker):
                 json=params,
                 timeout=self.timeout,
             )
+            if response.status_code not in [200, 201]:
+                raise Exception(f"Request to {url} failed, error: {response.text}")
             return ModelMetadata.from_dict(response.json())
 
     def get_model_metadata(self, params: Dict) -> ModelMetadata:
@@ -141,6 +147,8 @@ class RemoteModelWorker(ModelWorker):
                 json=params,
                 timeout=self.timeout,
             )
+            if response.status_code not in [200, 201]:
+                raise Exception(f"Request to {url} failed, error: {response.text}")
             return response.json()
 
     async def async_embeddings(self, params: Dict) -> List[List[float]]:
@@ -156,6 +164,8 @@ class RemoteModelWorker(ModelWorker):
                 json=params,
                 timeout=self.timeout,
             )
+            if response.status_code not in [200, 201]:
+                raise Exception(f"Request to {url} failed, error: {response.text}")
             return response.json()
 
     def _get_trace_headers(self):
@@ -1,6 +1,7 @@
 """
 Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py
 """
+
 import logging
 import re
 from typing import Dict
@@ -62,11 +63,11 @@ class LlamaCppModel:
         self.model.__del__()
 
     @classmethod
-    def from_pretrained(self, model_path, model_params: LlamaCppModelParameters):
+    def from_pretrained(cls, model_path, model_params: LlamaCppModelParameters):
         Llama = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).Llama
         LlamaCache = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).LlamaCache
 
-        result = self()
+        result = cls()
         cache_capacity = 0
         cache_capacity_str = model_params.cache_capacity
         if cache_capacity_str is not None:
docs/docs/installation/advanced_usage/Llamacpp_server.md (new file, 40 lines)

# LLama.cpp Server

DB-GPT supports native [llama.cpp server](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md),
which supports concurrent requests and continuous batching inference.

## Install dependencies

```bash
pip install -e ".[llama_cpp_server]"
```

If you want to accelerate the inference speed, and you have a GPU, you can install the following dependencies:

```bash
CMAKE_ARGS="-DGGML_CUDA=ON" pip install -e ".[llama_cpp_server]"
```

## Download the model

Here, we use the `qwen2.5-0.5b-instruct` model as an example. You can download the model from the [Huggingface](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct-GGUF).

```bash
wget https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/qwen2.5-0.5b-instruct-q4_k_m.gguf?download=true -O /tmp/qwen2.5-0.5b-instruct-q4_k_m.gguf
```

## Modify configuration file

In the `.env` configuration file, modify the inference type of the model to start `llama.cpp` inference.

```bash
LLM_MODEL=qwen2.5-0.5b-instruct
LLM_MODEL_PATH=/tmp/qwen2.5-0.5b-instruct-q4_k_m.gguf
MODEL_TYPE=llama_cpp_server
```

## Start the DB-GPT server

```bash
python dbgpt/app/dbgpt_server.py
```
@@ -271,6 +271,10 @@ const sidebars = {
               type: 'doc',
               id: 'installation/advanced_usage/vLLM_inference',
             },
+            {
+              type: 'doc',
+              id: 'installation/advanced_usage/Llamacpp_server',
+            },
             {
               type: 'doc',
               id: 'installation/advanced_usage/OpenAI_SDK_call',
setup.py (5 lines changed)
@@ -556,7 +556,10 @@ def llama_cpp_requires():
     """
     pip install "dbgpt[llama_cpp]"
     """
-    setup_spec.extras["llama_cpp"] = ["llama-cpp-python"]
+    setup_spec.extras["llama_cpp_server"] = ["llama-cpp-server-py"]
+    setup_spec.extras["llama_cpp"] = setup_spec.extras["llama_cpp_server"] + [
+        "llama-cpp-python"
+    ]
     llama_cpp_python_cuda_requires()