mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-06 19:40:13 +00:00
feat(model): Support llama.cpp server deploy (#2263)
This commit is contained in:
@@ -145,6 +145,14 @@ class LLMModelAdapter(ABC):
|
||||
"""Whether the loaded model supports asynchronous calls"""
|
||||
return False
|
||||
|
||||
def support_generate_function(self) -> bool:
|
||||
"""Whether support generate function, if it is False, we will use
|
||||
generate_stream function.
|
||||
|
||||
Sometimes, we need to use generate function to get the result of the model.
|
||||
"""
|
||||
return False
|
||||
|
||||
def get_generate_stream_function(self, model, model_path: str):
|
||||
"""Get the generate stream function of the model"""
|
||||
raise NotImplementedError
|
||||
@@ -153,6 +161,14 @@ class LLMModelAdapter(ABC):
|
||||
"""Get the asynchronous generate stream function of the model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_generate_function(self, model, model_path: str):
|
||||
"""Get the generate function of the model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_async_generate_function(self, model, model_path: str):
|
||||
"""Get the asynchronous generate function of the model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_default_conv_template(
|
||||
self, model_name: str, model_path: str
|
||||
) -> Optional[ConversationAdapter]:
|
||||
|
256
dbgpt/model/adapter/llama_cpp_adapter.py
Normal file
256
dbgpt/model/adapter/llama_cpp_adapter.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""llama.cpp server adapter.
|
||||
|
||||
See more details:
|
||||
https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md
|
||||
|
||||
**Features:**
|
||||
* LLM inference of F16 and quantized models on GPU and CPU
|
||||
* Parallel decoding with multi-user support
|
||||
* Continuous batching
|
||||
|
||||
The llama.cpp server is pure C++ server, we need to use the llama-cpp-server-py-core
|
||||
to interact with it.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
from dbgpt.configs.model_config import get_device
|
||||
from dbgpt.core import ModelOutput
|
||||
from dbgpt.model.adapter.base import ConversationAdapter, LLMModelAdapter
|
||||
from dbgpt.model.base import ModelType
|
||||
from dbgpt.model.parameter import ModelParameters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
try:
|
||||
from llama_cpp_server_py_core import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionStreamResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
LlamaCppServer,
|
||||
ServerConfig,
|
||||
ServerProcess,
|
||||
)
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"Failed to import llama_cpp_server_py_core, please install it first by `pip install llama-cpp-server-py-core`"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LlamaServerParameters(ServerConfig, ModelParameters):
|
||||
lora_files: Optional[str] = dataclasses.field(
|
||||
default=None, metadata={"help": "Lora files path"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.model_name:
|
||||
self.model_alias = self.model_name
|
||||
|
||||
if self.model_path and not self.model_file:
|
||||
self.model_file = self.model_path
|
||||
|
||||
if self.lora_files and isinstance(self.lora_files, str):
|
||||
self.lora_files = self.lora_files.split(",") # type: ignore
|
||||
elif not self.lora_files:
|
||||
self.lora_files = [] # type: ignore
|
||||
|
||||
if self.model_path:
|
||||
self.model_hf_repo = None
|
||||
self.model_hf_file = None
|
||||
device = self.device or get_device()
|
||||
if device and device == "cuda" and not self.n_gpu_layers:
|
||||
# Set n_gpu_layers to a large number to use all layers
|
||||
logger.info("Set n_gpu_layers to a large number to use all layers")
|
||||
self.n_gpu_layers = 1000000000
|
||||
|
||||
|
||||
class LLamaServerModelAdapter(LLMModelAdapter):
|
||||
def new_adapter(self, **kwargs) -> "LLamaServerModelAdapter":
|
||||
return self.__class__()
|
||||
|
||||
def model_type(self) -> str:
|
||||
return ModelType.LLAMA_CPP_SERVER
|
||||
|
||||
def model_param_class(self, model_type: str = None) -> Type[LlamaServerParameters]:
|
||||
return LlamaServerParameters
|
||||
|
||||
def get_default_conv_template(
|
||||
self, model_name: str, model_path: str
|
||||
) -> Optional[ConversationAdapter]:
|
||||
return None
|
||||
|
||||
def load_from_params(self, params: LlamaServerParameters):
|
||||
server = ServerProcess(params)
|
||||
server.start(300)
|
||||
model_server = LlamaCppServer(server, params)
|
||||
return model_server, model_server
|
||||
|
||||
def support_generate_function(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_generate_stream_function(self, model, model_path: str):
|
||||
return generate_stream
|
||||
|
||||
def get_generate_function(self, model, model_path: str):
|
||||
return generate
|
||||
|
||||
|
||||
def generate_stream(
|
||||
model: LlamaCppServer,
|
||||
tokenizer: LlamaCppServer,
|
||||
params: Dict,
|
||||
device: str,
|
||||
context_len: int,
|
||||
):
|
||||
chat_model = params.get("chat_model", True)
|
||||
if chat_model is None:
|
||||
chat_model = True
|
||||
if chat_model:
|
||||
for out in chat_generate_stream(model, params):
|
||||
yield out
|
||||
else:
|
||||
req = _build_completion_request(params, stream=True)
|
||||
# resp = model.stream_completion(req)
|
||||
text = ""
|
||||
for r in model.stream_completion(req):
|
||||
text += r.content
|
||||
timings = r.timings
|
||||
usage = {
|
||||
"completion_tokens": r.tokens_predicted,
|
||||
"prompt_tokens": r.tokens_evaluated,
|
||||
"total_tokens": r.tokens_predicted + r.tokens_evaluated,
|
||||
}
|
||||
if timings:
|
||||
logger.debug(f"Timings: {timings}")
|
||||
yield ModelOutput(
|
||||
text=text,
|
||||
error_code=0,
|
||||
finish_reason=_parse_finish_reason(r.stop_type),
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
def chat_generate_stream(
|
||||
model: LlamaCppServer,
|
||||
params: Dict,
|
||||
):
|
||||
req = _build_chat_completion_request(params, stream=True)
|
||||
text = ""
|
||||
for r in model.stream_chat_completion(req):
|
||||
if len(r.choices) == 0:
|
||||
continue
|
||||
# Check for empty 'choices' issue in Azure GPT-4o responses
|
||||
if r.choices[0] is not None and r.choices[0].delta is None:
|
||||
continue
|
||||
content = r.choices[0].delta.content
|
||||
finish_reason = _parse_finish_reason(r.choices[0].finish_reason)
|
||||
|
||||
if content is not None:
|
||||
content = r.choices[0].delta.content
|
||||
text += content
|
||||
yield ModelOutput(
|
||||
text=text, error_code=0, finish_reason=finish_reason, usage=r.usage
|
||||
)
|
||||
elif text and content is None:
|
||||
# Last response is empty, return the text
|
||||
yield ModelOutput(
|
||||
text=text, error_code=0, finish_reason=finish_reason, usage=r.usage
|
||||
)
|
||||
|
||||
|
||||
def _build_chat_completion_request(
|
||||
params: Dict, stream: bool = True
|
||||
) -> ChatCompletionRequest:
|
||||
from dbgpt.model.proxy.llms.proxy_model import parse_model_request
|
||||
|
||||
# LLamaCppServer does not need to parse the model
|
||||
model_request = parse_model_request(params, "", stream=stream)
|
||||
return ChatCompletionRequest(
|
||||
messages=model_request.to_common_messages(),
|
||||
temperature=params.get("temperature"),
|
||||
top_p=params.get("top_p"),
|
||||
top_k=params.get("top_k"),
|
||||
max_tokens=params.get("max_new_tokens"),
|
||||
stop=params.get("stop"),
|
||||
stream=stream,
|
||||
presence_penalty=params.get("presence_penalty"),
|
||||
frequency_penalty=params.get("frequency_penalty"),
|
||||
user=params.get("user_name"),
|
||||
)
|
||||
|
||||
|
||||
def _build_completion_request(params: Dict, stream: bool = True) -> CompletionRequest:
|
||||
from dbgpt.model.proxy.llms.proxy_model import parse_model_request
|
||||
|
||||
# LLamaCppServer does not need to parse the model
|
||||
model_request = parse_model_request(params, "", stream=stream)
|
||||
prompt = params.get("prompt")
|
||||
if not prompt and model_request.messages:
|
||||
prompt = model_request.messages[-1].content
|
||||
if not prompt:
|
||||
raise ValueError("Prompt is required for non-chat model")
|
||||
|
||||
return CompletionRequest(
|
||||
prompt=prompt,
|
||||
temperature=params.get("temperature"),
|
||||
top_p=params.get("top_p"),
|
||||
top_k=params.get("top_k"),
|
||||
n_predict=params.get("max_new_tokens"),
|
||||
stop=params.get("stop"),
|
||||
stream=stream,
|
||||
presence_penalty=params.get("presence_penalty"),
|
||||
frequency_penalty=params.get("frequency_penalty"),
|
||||
)
|
||||
|
||||
|
||||
def generate(
|
||||
model: LlamaCppServer,
|
||||
tokenizer: LlamaCppServer,
|
||||
params: Dict,
|
||||
device: str,
|
||||
context_len: int,
|
||||
):
|
||||
chat_model = params.get("chat_model", True)
|
||||
if chat_model is None:
|
||||
chat_model = True
|
||||
|
||||
if chat_model:
|
||||
req = _build_chat_completion_request(params, stream=False)
|
||||
resp = model.chat_completion(req)
|
||||
if not resp.choices or not resp.choices[0].message:
|
||||
raise ValueError("Response can't be empty")
|
||||
content = resp.choices[0].message.content
|
||||
return ModelOutput(
|
||||
text=content,
|
||||
error_code=0,
|
||||
finish_reason=_parse_finish_reason(resp.choices[0].finish_reason),
|
||||
usage=resp.usage,
|
||||
)
|
||||
|
||||
else:
|
||||
req = _build_completion_request(params, stream=False)
|
||||
resp = model.completion(req)
|
||||
content = resp.content
|
||||
usage = {
|
||||
"completion_tokens": resp.tokens_predicted,
|
||||
"prompt_tokens": resp.tokens_evaluated,
|
||||
"total_tokens": resp.tokens_predicted + resp.tokens_evaluated,
|
||||
}
|
||||
return ModelOutput(
|
||||
text=content,
|
||||
error_code=0,
|
||||
finish_reason=_parse_finish_reason(resp.stop_type),
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
def _parse_finish_reason(finish_reason: Optional[str]) -> Optional[str]:
|
||||
if finish_reason == "limit":
|
||||
return "length"
|
||||
elif finish_reason is not None:
|
||||
return "stop"
|
||||
return None
|
@@ -130,6 +130,8 @@ class ModelLoader:
|
||||
return llm_adapter.load_from_params(model_params)
|
||||
elif model_type == ModelType.VLLM:
|
||||
return llm_adapter.load_from_params(model_params)
|
||||
elif model_type == ModelType.LLAMA_CPP_SERVER:
|
||||
return llm_adapter.load_from_params(model_params)
|
||||
else:
|
||||
raise Exception(f"Unkown model type {model_type}")
|
||||
|
||||
|
@@ -42,6 +42,13 @@ def get_llm_model_adapter(
|
||||
from dbgpt.model.adapter.vllm_adapter import VLLMModelAdapterWrapper
|
||||
|
||||
return VLLMModelAdapterWrapper(conv_factory)
|
||||
if model_type == ModelType.LLAMA_CPP_SERVER:
|
||||
logger.info(
|
||||
"Current model type is llama_cpp_server, return LLamaServerModelAdapter"
|
||||
)
|
||||
from dbgpt.model.adapter.llama_cpp_adapter import LLamaServerModelAdapter
|
||||
|
||||
return LLamaServerModelAdapter()
|
||||
|
||||
# Import NewHFChatModelAdapter for it can be registered
|
||||
from dbgpt.model.adapter.hf_adapter import NewHFChatModelAdapter
|
||||
|
@@ -10,10 +10,11 @@ from dbgpt.util.parameter_utils import ParameterDescription
|
||||
|
||||
|
||||
class ModelType:
|
||||
""" "Type of model"""
|
||||
"""Type of model."""
|
||||
|
||||
HF = "huggingface"
|
||||
LLAMA_CPP = "llama.cpp"
|
||||
LLAMA_CPP_SERVER = "llama_cpp_server"
|
||||
PROXY = "proxy"
|
||||
VLLM = "vllm"
|
||||
# TODO, support more model type
|
||||
|
@@ -27,6 +27,11 @@ from dbgpt.core.schema.api import (
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatMessage,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
DeltaMessage,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
@@ -272,7 +277,13 @@ class APIServer(BaseComponent):
|
||||
worker_manager = self.get_worker_manager()
|
||||
id = f"chatcmpl-{shortuuid.random()}"
|
||||
finish_stream_events = []
|
||||
curr_usage = UsageInfo()
|
||||
last_usage = UsageInfo()
|
||||
for i in range(n):
|
||||
last_usage.prompt_tokens += curr_usage.prompt_tokens
|
||||
last_usage.completion_tokens += curr_usage.completion_tokens
|
||||
last_usage.total_tokens += curr_usage.total_tokens
|
||||
|
||||
# First chunk with role
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
@@ -280,7 +291,10 @@ class APIServer(BaseComponent):
|
||||
finish_reason=None,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=id, choices=[choice_data], model=model_name
|
||||
id=id,
|
||||
choices=[choice_data],
|
||||
model=model_name,
|
||||
usage=last_usage,
|
||||
)
|
||||
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {json_data}\n\n"
|
||||
@@ -307,13 +321,28 @@ class APIServer(BaseComponent):
|
||||
delta=DeltaMessage(content=delta_text),
|
||||
finish_reason=model_output.finish_reason,
|
||||
)
|
||||
has_usage = False
|
||||
if model_output.usage:
|
||||
curr_usage = UsageInfo.model_validate(model_output.usage)
|
||||
has_usage = True
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=last_usage.prompt_tokens
|
||||
+ curr_usage.prompt_tokens,
|
||||
total_tokens=last_usage.total_tokens + curr_usage.total_tokens,
|
||||
completion_tokens=last_usage.completion_tokens
|
||||
+ curr_usage.completion_tokens,
|
||||
)
|
||||
else:
|
||||
has_usage = False
|
||||
usage = UsageInfo()
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=id, choices=[choice_data], model=model_name
|
||||
id=id, choices=[choice_data], model=model_name, usage=usage
|
||||
)
|
||||
if delta_text is None:
|
||||
if model_output.finish_reason is not None:
|
||||
finish_stream_events.append(chunk)
|
||||
continue
|
||||
if not has_usage:
|
||||
continue
|
||||
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {json_data}\n\n"
|
||||
|
||||
@@ -363,6 +392,118 @@ class APIServer(BaseComponent):
|
||||
|
||||
return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
|
||||
|
||||
async def completion_stream_generator(
|
||||
self, request: CompletionRequest, params: Dict
|
||||
):
|
||||
worker_manager = self.get_worker_manager()
|
||||
id = f"cmpl-{shortuuid.random()}"
|
||||
finish_stream_events = []
|
||||
params["span_id"] = root_tracer.get_current_span_id()
|
||||
curr_usage = UsageInfo()
|
||||
last_usage = UsageInfo()
|
||||
for text in request.prompt:
|
||||
for i in range(request.n):
|
||||
params["prompt"] = text
|
||||
previous_text = ""
|
||||
last_usage.prompt_tokens += curr_usage.prompt_tokens
|
||||
last_usage.completion_tokens += curr_usage.completion_tokens
|
||||
last_usage.total_tokens += curr_usage.total_tokens
|
||||
|
||||
async for model_output in worker_manager.generate_stream(params):
|
||||
model_output: ModelOutput = model_output
|
||||
if model_output.error_code != 0:
|
||||
yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
decoded_unicode = model_output.text.replace("\ufffd", "")
|
||||
delta_text = decoded_unicode[len(previous_text) :]
|
||||
previous_text = (
|
||||
decoded_unicode
|
||||
if len(decoded_unicode) > len(previous_text)
|
||||
else previous_text
|
||||
)
|
||||
|
||||
if len(delta_text) == 0:
|
||||
delta_text = None
|
||||
|
||||
choice_data = CompletionResponseStreamChoice(
|
||||
index=i,
|
||||
text=delta_text or "",
|
||||
# TODO: logprobs
|
||||
logprobs=None,
|
||||
finish_reason=model_output.finish_reason,
|
||||
)
|
||||
if model_output.usage:
|
||||
curr_usage = UsageInfo.model_validate(model_output.usage)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=last_usage.prompt_tokens
|
||||
+ curr_usage.prompt_tokens,
|
||||
total_tokens=last_usage.total_tokens
|
||||
+ curr_usage.total_tokens,
|
||||
completion_tokens=last_usage.completion_tokens
|
||||
+ curr_usage.completion_tokens,
|
||||
)
|
||||
else:
|
||||
usage = UsageInfo()
|
||||
chunk = CompletionStreamResponse(
|
||||
id=id,
|
||||
object="text_completion",
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
usage=UsageInfo.model_validate(usage),
|
||||
)
|
||||
if delta_text is None:
|
||||
if model_output.finish_reason is not None:
|
||||
finish_stream_events.append(chunk)
|
||||
continue
|
||||
json_data = model_to_json(
|
||||
chunk, exclude_unset=True, ensure_ascii=False
|
||||
)
|
||||
yield f"data: {json_data}\n\n"
|
||||
last_usage = curr_usage
|
||||
# There is not "content" field in the last delta message, so exclude_none to exclude field "content".
|
||||
for finish_chunk in finish_stream_events:
|
||||
json_data = model_to_json(
|
||||
finish_chunk, exclude_unset=True, ensure_ascii=False
|
||||
)
|
||||
yield f"data: {json_data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def completion_generate(
|
||||
self, request: CompletionRequest, params: Dict[str, Any]
|
||||
):
|
||||
worker_manager: WorkerManager = self.get_worker_manager()
|
||||
choices = []
|
||||
completions = []
|
||||
for text in request.prompt:
|
||||
for i in range(request.n):
|
||||
params["prompt"] = text
|
||||
model_output = asyncio.create_task(worker_manager.generate(params))
|
||||
completions.append(model_output)
|
||||
try:
|
||||
all_tasks = await asyncio.gather(*completions)
|
||||
except Exception as e:
|
||||
return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
|
||||
usage = UsageInfo()
|
||||
for i, model_output in enumerate(all_tasks):
|
||||
model_output: ModelOutput = model_output
|
||||
if model_output.error_code != 0:
|
||||
return create_error_response(model_output.error_code, model_output.text)
|
||||
choices.append(
|
||||
CompletionResponseChoice(
|
||||
index=i,
|
||||
text=model_output.text,
|
||||
finish_reason=model_output.finish_reason,
|
||||
)
|
||||
)
|
||||
if model_output.usage:
|
||||
task_usage = UsageInfo.model_validate(model_output.usage)
|
||||
for usage_key, usage_value in model_to_dict(task_usage).items():
|
||||
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
|
||||
return CompletionResponse(
|
||||
model=request.model, choices=choices, usage=UsageInfo.model_validate(usage)
|
||||
)
|
||||
|
||||
async def embeddings_generate(
|
||||
self,
|
||||
model: str,
|
||||
@@ -485,6 +626,57 @@ async def create_chat_completion(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/completions", dependencies=[Depends(check_api_key)])
|
||||
async def create_completion(
|
||||
request: CompletionRequest, api_server: APIServer = Depends(get_api_server)
|
||||
):
|
||||
await api_server.get_model_instances_or_raise(request.model)
|
||||
error_check_ret = check_requests(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
if isinstance(request.prompt, str):
|
||||
request.prompt = [request.prompt]
|
||||
elif not isinstance(request.prompt, list):
|
||||
return create_error_response(
|
||||
ErrorCode.VALIDATION_TYPE_ERROR,
|
||||
"prompt must be a string or a list of strings",
|
||||
)
|
||||
elif isinstance(request.prompt, list) and not isinstance(request.prompt[0], str):
|
||||
return create_error_response(
|
||||
ErrorCode.VALIDATION_TYPE_ERROR,
|
||||
"prompt must be a string or a list of strings",
|
||||
)
|
||||
|
||||
params = {
|
||||
"model": request.model,
|
||||
"prompt": request.prompt,
|
||||
"chat_model": False,
|
||||
"temperature": request.temperature,
|
||||
"max_new_tokens": request.max_tokens,
|
||||
"stop": request.stop,
|
||||
"top_p": request.top_p,
|
||||
"top_k": request.top_k,
|
||||
"echo": request.echo,
|
||||
"presence_penalty": request.presence_penalty,
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"user": request.user,
|
||||
# "use_beam_search": request.use_beam_search,
|
||||
# "beam_size": request.beam_size,
|
||||
}
|
||||
trace_kwargs = {
|
||||
"operation_name": "dbgpt.model.apiserver.create_completion",
|
||||
"metadata": {k: v for k, v in params.items() if v},
|
||||
}
|
||||
if request.stream:
|
||||
generator = api_server.completion_stream_generator(request, params)
|
||||
trace_generator = root_tracer.wrapper_async_stream(generator, **trace_kwargs)
|
||||
return StreamingResponse(trace_generator, media_type="text/event-stream")
|
||||
else:
|
||||
with root_tracer.start_span(**trace_kwargs):
|
||||
params["span_id"] = root_tracer.get_current_span_id()
|
||||
return await api_server.completion_generate(request, params)
|
||||
|
||||
|
||||
@router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
|
||||
async def create_embeddings(
|
||||
request: EmbeddingsRequest, api_server: APIServer = Depends(get_api_server)
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.core.interface.message import ModelMessage
|
||||
from dbgpt.model.base import WorkerApplyType
|
||||
from dbgpt.model.parameter import WorkerType
|
||||
@@ -10,10 +10,14 @@ WORKER_MANAGER_SERVICE_NAME = "WorkerManager"
|
||||
|
||||
|
||||
class PromptRequest(BaseModel):
|
||||
messages: List[ModelMessage]
|
||||
model: str
|
||||
messages: List[ModelMessage] = Field(
|
||||
default_factory=list, description="List of ModelMessage objects"
|
||||
)
|
||||
prompt: str = None
|
||||
temperature: float = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
max_new_tokens: int = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop_token_ids: List[int] = []
|
||||
@@ -26,6 +30,10 @@ class PromptRequest(BaseModel):
|
||||
"""Message version, default to v2"""
|
||||
context: Dict[str, Any] = None
|
||||
"""Context information for the model"""
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
chat_model: Optional[bool] = True
|
||||
"""Whether to use chat model"""
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
|
||||
|
@@ -55,6 +55,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
model_type = self.llm_adapter.model_type()
|
||||
self.param_cls = self.llm_adapter.model_param_class(model_type)
|
||||
self._support_async = self.llm_adapter.support_async()
|
||||
self._support_generate_func = self.llm_adapter.support_generate_function()
|
||||
|
||||
logger.info(
|
||||
f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}"
|
||||
@@ -85,7 +86,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
model_path=self.model_path,
|
||||
model_type=model_type,
|
||||
)
|
||||
if not model_params.device:
|
||||
if hasattr(model_params, "device") and not model_params.device:
|
||||
model_params.device = get_device()
|
||||
logger.info(
|
||||
f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
|
||||
@@ -129,7 +130,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
|
||||
def stop(self) -> None:
|
||||
if not self.model:
|
||||
logger.warn("Model has been stopped!!")
|
||||
logger.warning("Model has been stopped!!")
|
||||
return
|
||||
del self.model
|
||||
del self.tokenizer
|
||||
@@ -190,9 +191,40 @@ class DefaultModelWorker(ModelWorker):
|
||||
def generate(self, params: Dict) -> ModelOutput:
|
||||
"""Generate non stream result"""
|
||||
output = None
|
||||
for out in self.generate_stream(params):
|
||||
output = out
|
||||
return output
|
||||
if self._support_generate_func:
|
||||
(
|
||||
params,
|
||||
model_context,
|
||||
generate_stream_func,
|
||||
model_span,
|
||||
) = self._prepare_generate_stream(
|
||||
params,
|
||||
span_operation_name="DefaultModelWorker_call.generate_func",
|
||||
is_stream=False,
|
||||
)
|
||||
previous_response = ""
|
||||
last_metrics = ModelInferenceMetrics.create_metrics()
|
||||
is_first_generate = True
|
||||
output = generate_stream_func(
|
||||
self.model, self.tokenizer, params, get_device(), self.context_len
|
||||
)
|
||||
(
|
||||
model_output,
|
||||
incremental_output,
|
||||
output_str,
|
||||
current_metrics,
|
||||
) = self._handle_output(
|
||||
output,
|
||||
previous_response,
|
||||
model_context,
|
||||
last_metrics,
|
||||
is_first_generate,
|
||||
)
|
||||
return model_output
|
||||
else:
|
||||
for out in self.generate_stream(params):
|
||||
output = out
|
||||
return output
|
||||
|
||||
def count_token(self, prompt: str) -> int:
|
||||
return _try_to_count_token(prompt, self.tokenizer, self.model)
|
||||
@@ -277,12 +309,45 @@ class DefaultModelWorker(ModelWorker):
|
||||
span.end(metadata={"error": output.to_dict()})
|
||||
|
||||
async def async_generate(self, params: Dict) -> ModelOutput:
|
||||
output = None
|
||||
async for out in self.async_generate_stream(params):
|
||||
output = out
|
||||
return output
|
||||
if self._support_generate_func:
|
||||
(
|
||||
params,
|
||||
model_context,
|
||||
generate_stream_func,
|
||||
model_span,
|
||||
) = self._prepare_generate_stream(
|
||||
params,
|
||||
span_operation_name="DefaultModelWorker_call.generate_func",
|
||||
is_stream=False,
|
||||
)
|
||||
previous_response = ""
|
||||
last_metrics = ModelInferenceMetrics.create_metrics()
|
||||
is_first_generate = True
|
||||
output = await generate_stream_func(
|
||||
self.model, self.tokenizer, params, get_device(), self.context_len
|
||||
)
|
||||
(
|
||||
model_output,
|
||||
incremental_output,
|
||||
output_str,
|
||||
current_metrics,
|
||||
) = self._handle_output(
|
||||
output,
|
||||
previous_response,
|
||||
model_context,
|
||||
last_metrics,
|
||||
is_first_generate,
|
||||
)
|
||||
return model_output
|
||||
else:
|
||||
output = None
|
||||
async for out in self.async_generate_stream(params):
|
||||
output = out
|
||||
return output
|
||||
|
||||
def _prepare_generate_stream(self, params: Dict, span_operation_name: str):
|
||||
def _prepare_generate_stream(
|
||||
self, params: Dict, span_operation_name: str, is_stream=True
|
||||
):
|
||||
params, model_context = self.llm_adapter.model_adaptation(
|
||||
params,
|
||||
self.model_name,
|
||||
@@ -290,29 +355,48 @@ class DefaultModelWorker(ModelWorker):
|
||||
self.tokenizer,
|
||||
prompt_template=self.ml.prompt_template,
|
||||
)
|
||||
stream_type = ""
|
||||
if self.support_async():
|
||||
generate_stream_func = self.llm_adapter.get_async_generate_stream_function(
|
||||
self.model, self.model_path
|
||||
)
|
||||
stream_type = "async "
|
||||
logger.info(
|
||||
"current generate stream function is asynchronous stream function"
|
||||
)
|
||||
if not is_stream and self.llm_adapter.support_generate_function():
|
||||
func = self.llm_adapter.get_generate_function(
|
||||
self.model, self.model_path
|
||||
)
|
||||
func_type = "async generate"
|
||||
logger.info(
|
||||
"current generate function is asynchronous generate function"
|
||||
)
|
||||
else:
|
||||
func = self.llm_adapter.get_async_generate_stream_function(
|
||||
self.model, self.model_path
|
||||
)
|
||||
func_type = "async generate stream"
|
||||
logger.info(
|
||||
"current generate stream function is asynchronous generate stream function"
|
||||
)
|
||||
else:
|
||||
generate_stream_func = self.llm_adapter.get_generate_stream_function(
|
||||
self.model, self.model_path
|
||||
)
|
||||
if not is_stream and self.llm_adapter.support_generate_function():
|
||||
func = self.llm_adapter.get_generate_function(
|
||||
self.model, self.model_path
|
||||
)
|
||||
func_type = "generate"
|
||||
logger.info(
|
||||
"current generate function is synchronous generate function"
|
||||
)
|
||||
else:
|
||||
func = self.llm_adapter.get_generate_stream_function(
|
||||
self.model, self.model_path
|
||||
)
|
||||
func_type = "generate stream"
|
||||
logger.info(
|
||||
"current generate stream function is synchronous generate stream function"
|
||||
)
|
||||
str_prompt = params.get("prompt")
|
||||
if not str_prompt:
|
||||
str_prompt = params.get("string_prompt")
|
||||
print(
|
||||
f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n"
|
||||
f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{func_type} output:\n"
|
||||
)
|
||||
|
||||
generate_stream_func_str_name = "{}.{}".format(
|
||||
generate_stream_func.__module__, generate_stream_func.__name__
|
||||
)
|
||||
generate_func_str_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
|
||||
span_params = {k: v for k, v in params.items()}
|
||||
if "messages" in span_params:
|
||||
@@ -323,7 +407,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
metadata = {
|
||||
"is_async_func": self.support_async(),
|
||||
"llm_adapter": str(self.llm_adapter),
|
||||
"generate_stream_func": generate_stream_func_str_name,
|
||||
"generate_func": generate_func_str_name,
|
||||
}
|
||||
metadata.update(span_params)
|
||||
metadata.update(model_context)
|
||||
@@ -331,7 +415,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
|
||||
model_span = root_tracer.start_span(span_operation_name, metadata=metadata)
|
||||
|
||||
return params, model_context, generate_stream_func, model_span
|
||||
return params, model_context, func, model_span
|
||||
|
||||
def _handle_output(
|
||||
self,
|
||||
|
@@ -89,6 +89,8 @@ class RemoteModelWorker(ModelWorker):
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
if response.status_code not in [200, 201]:
|
||||
raise Exception(f"Request to {url} failed, error: {response.text}")
|
||||
return ModelOutput(**response.json())
|
||||
|
||||
def count_token(self, prompt: str) -> int:
|
||||
@@ -106,6 +108,8 @@ class RemoteModelWorker(ModelWorker):
|
||||
json={"prompt": prompt},
|
||||
timeout=self.timeout,
|
||||
)
|
||||
if response.status_code not in [200, 201]:
|
||||
raise Exception(f"Request to {url} failed, error: {response.text}")
|
||||
return response.json()
|
||||
|
||||
async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
|
||||
@@ -123,6 +127,8 @@ class RemoteModelWorker(ModelWorker):
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
if response.status_code not in [200, 201]:
|
||||
raise Exception(f"Request to {url} failed, error: {response.text}")
|
||||
return ModelMetadata.from_dict(response.json())
|
||||
|
||||
def get_model_metadata(self, params: Dict) -> ModelMetadata:
|
||||
@@ -141,6 +147,8 @@ class RemoteModelWorker(ModelWorker):
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
if response.status_code not in [200, 201]:
|
||||
raise Exception(f"Request to {url} failed, error: {response.text}")
|
||||
return response.json()
|
||||
|
||||
async def async_embeddings(self, params: Dict) -> List[List[float]]:
|
||||
@@ -156,6 +164,8 @@ class RemoteModelWorker(ModelWorker):
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
if response.status_code not in [200, 201]:
|
||||
raise Exception(f"Request to {url} failed, error: {response.text}")
|
||||
return response.json()
|
||||
|
||||
def _get_trace_headers(self):
|
||||
|
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict
|
||||
@@ -62,11 +63,11 @@ class LlamaCppModel:
|
||||
self.model.__del__()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, model_path, model_params: LlamaCppModelParameters):
|
||||
def from_pretrained(cls, model_path, model_params: LlamaCppModelParameters):
|
||||
Llama = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).Llama
|
||||
LlamaCache = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).LlamaCache
|
||||
|
||||
result = self()
|
||||
result = cls()
|
||||
cache_capacity = 0
|
||||
cache_capacity_str = model_params.cache_capacity
|
||||
if cache_capacity_str is not None:
|
||||
|
Reference in New Issue
Block a user